diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 57341ddc..03945e25 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,7 +10,8 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7", "3.9", "3.11"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + sqlalchemy-version: ["1.4", "2.0"] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 @@ -18,10 +19,15 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + allow-prereleases: true cache: pip cache-dependency-path: pyproject.toml - - name: Install dependencies - run: pip install -e .[test] coveralls + - name: Install dependencies SQLAlchemy 1.4 + if: matrix.sqlalchemy-version == 1.4 + run: pip install -e .[test,sqlmodel] coveralls SQLAlchemy==1.4.* + - name: Install dependencies SQLAlchemy 2.0 + if: matrix.sqlalchemy-version == 2.0 + run: pip install -e .[test] coveralls SQLAlchemy==2.0.* - name: Test with pytest run: coverage run -m pytest - name: Upload Coverage diff --git a/CHANGES.rst b/CHANGES.rst index 17d3bbd4..ce5a7022 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,10 @@ Version history =============== +**3.0.0rc3** + +- Added support for SQLAlchemy 2 (PR by rbuffat with help from mhauru) + **3.0.0rc2** - Added support for generating SQLModel classes (PR by Andrii Khirilov) diff --git a/README.rst b/README.rst index 50f04496..37e6c42e 100644 --- a/README.rst +++ b/README.rst @@ -18,7 +18,7 @@ latest SQLAlchemy version). Features ======== -* Supports SQLAlchemy 1.4.x +* Supports SQLAlchemy 1.4.x and 2 * Produces declarative code that almost looks like it was hand written * Produces `PEP 8`_ compliant code * Accurately determines relationships, including many-to-many, one-to-one diff --git a/pyproject.toml b/pyproject.toml index 3e5b169e..d3d5982d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,12 +25,14 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", ] requires-python = ">=3.7" dependencies = [ - "SQLAlchemy >= 1.4.36, < 2.0", + "SQLAlchemy >= 1.4.36", "inflect >= 4.0.0", "importlib_metadata; python_version < '3.10'", + "greenlet >= 3.0.0a1; python_version >= '3.12'", ] dynamic = ["version"] @@ -44,6 +46,8 @@ test = [ "pytest-cov", "psycopg2-binary", "mysql-connector-python", +] +sqlmodel = [ "sqlmodel", ] citext = ["sqlalchemy-citext >= 1.7.0"] @@ -88,7 +92,7 @@ show_missing = true [tool.tox] legacy_tox_ini = """ [tox] -envlist = py37, py38, py39, py310 +envlist = py37, py38, py39, py310, py311, py312 skip_missing_interpreters = true isolated_build = true diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index e3b3c471..9ddb11cf 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -75,6 +75,23 @@ _re_invalid_identifier = re.compile(r"(?u)\W") +@dataclass +class LiteralImport: + pkgname: str + name: str + + +@dataclass +class Base: + """Representation of MetaData for Tables, respectively Base for classes""" + + literal_imports: list[LiteralImport] + declarations: list[str] + metadata_ref: str + decorator: str | None = None + table_metadata_declaration: str | None = None + + class CodeGenerator(metaclass=ABCMeta): valid_options: ClassVar[set[str]] = set() @@ -116,8 +133,18 @@ def __init__( super().__init__(metadata, bind, options) self.indentation: str = indentation self.imports: dict[str, set[str]] = defaultdict(set) + self.module_imports: set[str] = set() + + def generate_base(self) -> None: + self.base = Base( + literal_imports=[LiteralImport("sqlalchemy", "MetaData")], + declarations=["metadata = MetaData()"], + metadata_ref="metadata", + ) def generate(self) -> str: + self.generate_base() + sections: list[str] = [] # Remove unwanted elements from the metadata @@ -166,7 +193,9 @@ def generate(self) -> str: return "\n\n".join(sections) + "\n" def collect_imports(self, models: Iterable[Model]) -> None: - self.add_import(MetaData) + for literal_import in self.base.literal_imports: + self.add_literal_import(literal_import.pkgname, literal_import.name) + for model in models: self.collect_imports_for_model(model) @@ -184,7 +213,6 @@ def collect_imports_for_model(self, model: Model) -> None: self.collect_imports_for_constraint(index) def collect_imports_for_column(self, column: Column[Any]) -> None: - self.add_import(Column) self.add_import(column.type) if isinstance(column.type, ARRAY): @@ -254,7 +282,11 @@ def add_literal_import(self, pkgname: str, name: str) -> None: def remove_literal_import(self, pkgname: str, name: str) -> None: names = self.imports.setdefault(pkgname, set()) - names.remove(name) + if name in names: + names.remove(name) + + def add_module_import(self, pgkname: str) -> None: + self.module_imports.add(pgkname) def group_imports(self) -> list[list[str]]: future_imports: list[str] = [] @@ -274,6 +306,9 @@ def group_imports(self) -> list[list[str]]: collection.append(f"from {package} import {imports}") + for module in sorted(self.module_imports): + thirdparty_imports.append(f"import {module}") + return [ group for group in (future_imports, stdlib_imports, thirdparty_imports) @@ -301,10 +336,16 @@ def generate_model_name(self, model: Model, global_names: set[str]) -> None: model.name = self.find_free_name(preferred_name, global_names) def render_module_variables(self, models: list[Model]) -> str: - return "metadata = MetaData()" + declarations = self.base.declarations + + if any(not isinstance(model, ModelClass) for model in models): + if self.base.table_metadata_declaration is not None: + declarations.append(self.base.table_metadata_declaration) + + return "\n".join(declarations) def render_models(self, models: list[Model]) -> str: - rendered = [] + rendered: list[str] = [] for model in models: rendered_table = self.render_table(model.table) rendered.append(f"{model.name} = {rendered_table}") @@ -312,12 +353,12 @@ def render_models(self, models: list[Model]) -> str: return "\n\n".join(rendered) def render_table(self, table: Table) -> str: - args: list[str] = [f"{table.name!r}, metadata"] + args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"] kwargs: dict[str, object] = {} for column in table.columns: # Cast is required because of a bug in the SQLAlchemy stubs regarding # Table.columns - args.append(self.render_column(column, True)) + args.append(self.render_column(column, True, is_table=True)) for constraint in sorted(table.constraints, key=get_constraint_sort_key): if uses_default_name(constraint): @@ -351,7 +392,10 @@ def render_index(self, index: Index) -> str: return render_callable("Index", repr(index.name), *extra_args, kwargs=kwargs) - def render_column(self, column: Column[Any], show_name: bool) -> str: + # TODO find better solution for is_table + def render_column( + self, column: Column[Any], show_name: bool, is_table: bool = False + ) -> str: args = [] kwargs: dict[str, Any] = {} kwarg = [] @@ -373,11 +417,14 @@ def render_column(self, column: Column[Any], show_name: bool) -> str: i.unique and set(i.columns) == {column} and uses_default_name(i) for i in column.table.indexes ) - is_primary = any( - isinstance(c, PrimaryKeyConstraint) - and column.name in c.columns - and uses_default_name(c) - for c in column.table.constraints + is_primary = ( + any( + isinstance(c, PrimaryKeyConstraint) + and column.name in c.columns + and uses_default_name(c) + for c in column.table.constraints + ) + or column.primary_key ) has_index = any( set(i.columns) == {column} and uses_default_name(i) @@ -402,7 +449,11 @@ def render_column(self, column: Column[Any], show_name: bool) -> str: kwargs["key"] = column.key if is_primary: kwargs["primary_key"] = True - if not column.nullable and not is_sole_pk: + if ( + not column.nullable + and not is_sole_pk + and (_sqla_version < (2, 0) or is_table) + ): kwargs["nullable"] = False if is_unique: @@ -436,7 +487,11 @@ def render_column(self, column: Column[Any], show_name: bool) -> str: if comment: kwargs["comment"] = repr(comment) - return render_callable("Column", *args, kwargs=kwargs) + if _sqla_version < (2, 0) or is_table: + self.add_import(Column) + return render_callable("Column", *args, kwargs=kwargs) + else: + return render_callable("mapped_column", *args, kwargs=kwargs) def render_column_type(self, coltype: object) -> str: args = [] @@ -678,16 +733,41 @@ def __init__( self.base_class_name: str = base_class_name self.inflect_engine = inflect.engine() + def generate_base(self) -> None: + if _sqla_version < (1, 4): + table_decoration = f"metadata = {self.base_class_name}.metadata" + self.base = Base( + literal_imports=[ + LiteralImport("sqlalchemy.ext.declarative", "declarative_base") + ], + declarations=[f"{self.base_class_name} = declarative_base()"], + metadata_ref=self.base_class_name, + table_metadata_declaration=table_decoration, + ) + elif (1, 4) <= _sqla_version < (2, 0): + table_decoration = f"metadata = {self.base_class_name}.metadata" + self.base = Base( + literal_imports=[LiteralImport("sqlalchemy.orm", "declarative_base")], + declarations=[f"{self.base_class_name} = declarative_base()"], + metadata_ref="metadata", + table_metadata_declaration=table_decoration, + ) + else: + self.base = Base( + literal_imports=[LiteralImport("sqlalchemy.orm", "DeclarativeBase")], + declarations=[ + f"class {self.base_class_name}(DeclarativeBase):", + f"{self.indentation}pass", + ], + metadata_ref=f"{self.base_class_name}.metadata", + ) + def collect_imports(self, models: Iterable[Model]) -> None: super().collect_imports(models) if any(isinstance(model, ModelClass) for model in models): - self.remove_literal_import("sqlalchemy", "MetaData") - if _sqla_version < (1, 4): - self.add_literal_import( - "sqlalchemy.ext.declarative", "declarative_base" - ) - else: - self.add_literal_import("sqlalchemy.orm", "declarative_base") + if _sqla_version >= (2, 0): + self.add_literal_import("sqlalchemy.orm", "Mapped") + self.add_literal_import("sqlalchemy.orm", "mapped_column") def collect_imports_for_model(self, model: Model) -> None: super().collect_imports_for_model(model) @@ -753,6 +833,12 @@ def generate_models(self) -> list[Model]: model.parent_class = target target.children.append(model) + # Change base if we only have tables + if not any( + isinstance(model, ModelClass) for model in models_by_table_name.values() + ): + super().generate_base() + # Collect the imports self.collect_imports(models_by_table_name.values()) @@ -974,6 +1060,7 @@ def generate_relationship_name( local_names: set[str], ) -> None: # Self referential reverse relationships + preferred_name: str if ( relationship.type in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE) @@ -1014,18 +1101,8 @@ def generate_relationship_name( preferred_name, global_names, local_names ) - def render_module_variables(self, models: list[Model]) -> str: - if not any(isinstance(model, ModelClass) for model in models): - return super().render_module_variables(models) - - declarations = [f"{self.base_class_name} = declarative_base()"] - if any(not isinstance(model, ModelClass) for model in models): - declarations.append(f"metadata = {self.base_class_name}.metadata") - - return "\n".join(declarations) - def render_models(self, models: list[Model]) -> str: - rendered = [] + rendered: list[str] = [] for model in models: if isinstance(model, ModelClass): rendered.append(self.render_class(model)) @@ -1132,7 +1209,29 @@ def render_table_args(self, table: Table) -> str: def render_column_attribute(self, column_attr: ColumnAttribute) -> str: column = column_attr.column rendered_column = self.render_column(column, column_attr.name != column.name) - return f"{column_attr.name} = {rendered_column}" + + if _sqla_version < (2, 0): + return f"{column_attr.name} = {rendered_column}" + else: + try: + python_type = column.type.python_type + python_type_name = python_type.__name__ + if python_type.__module__ == "builtins": + column_python_type = python_type_name + else: + python_type_module = python_type.__module__ + column_python_type = f"{python_type_module}.{python_type_name}" + self.add_module_import(python_type_module) + except NotImplementedError: + self.add_literal_import("typing", "Any") + column_python_type = "Any" + + if column.nullable: + self.add_literal_import("typing", "Optional") + column_python_type = f"Optional[{column_python_type}]" + return ( + f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}" + ) def render_relationship(self, relationship: RelationshipAttribute) -> str: def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str: @@ -1209,7 +1308,29 @@ def render_join(terms: list[JoinType]) -> str: rendered_relationship = render_callable( "relationship", repr(relationship.target.name), kwargs=kwargs ) - return f"{relationship.name} = {rendered_relationship}" + + if _sqla_version < (2, 0): + return f"{relationship.name} = {rendered_relationship}" + else: + relationship_type: str + if relationship.type == RelationshipType.ONE_TO_MANY: + self.add_literal_import("typing", "List") + relationship_type = f"List['{relationship.target.name}']" + elif relationship.type in ( + RelationshipType.ONE_TO_ONE, + RelationshipType.MANY_TO_ONE, + ): + relationship_type = f"'{relationship.target.name}'" + elif relationship.type == RelationshipType.MANY_TO_MANY: + self.add_literal_import("typing", "List") + relationship_type = f"List['{relationship.target.name}']" + else: + self.add_literal_import("typing", "Any") + relationship_type = "Any" + return ( + f"{relationship.name}: Mapped[{relationship_type}] " + f"= {rendered_relationship}" + ) class DataclassGenerator(DeclarativeGenerator): @@ -1234,112 +1355,159 @@ def __init__( self.metadata_key: str = metadata_key self.quote_annotations: bool = quote_annotations + def generate_base(self) -> None: + if _sqla_version < (2, 0): + self.base = Base( + literal_imports=[LiteralImport("sqlalchemy.orm", "registry")], + declarations=["mapper_registry = registry()"], + metadata_ref="metadata", + decorator="@mapper_registry.mapped", + ) + else: + self.base = Base( + literal_imports=[ + LiteralImport("sqlalchemy.orm", "DeclarativeBase"), + LiteralImport("sqlalchemy.orm", "MappedAsDataclass"), + ], + declarations=[ + ( + f"class {self.base_class_name}(MappedAsDataclass, " + "DeclarativeBase):" + ), + f"{self.indentation}pass", + ], + metadata_ref=f"{self.base_class_name}.metadata", + ) + def collect_imports(self, models: Iterable[Model]) -> None: super().collect_imports(models) - if not self.quote_annotations: - self.add_literal_import("__future__", "annotations") + if _sqla_version < (2, 0): + if not self.quote_annotations: + self.add_literal_import("__future__", "annotations") - if any(isinstance(model, ModelClass) for model in models): - self.remove_literal_import("sqlalchemy.orm", "declarative_base") - self.add_literal_import("dataclasses", "dataclass") - self.add_literal_import("dataclasses", "field") - self.add_literal_import("sqlalchemy.orm", "registry") + if any(isinstance(model, ModelClass) for model in models): + self.remove_literal_import("sqlalchemy.orm", "declarative_base") + self.add_literal_import("dataclasses", "dataclass") + self.add_literal_import("dataclasses", "field") + self.add_literal_import("sqlalchemy.orm", "registry") def collect_imports_for_model(self, model: Model) -> None: super().collect_imports_for_model(model) - if isinstance(model, ModelClass): - for column_attr in model.columns: - if column_attr.column.nullable: - self.add_literal_import("typing", "Optional") - break - - for relationship_attr in model.relationships: - if relationship_attr.type in ( - RelationshipType.ONE_TO_MANY, - RelationshipType.MANY_TO_MANY, - ): - self.add_literal_import("typing", "List") + if _sqla_version < (2, 0): + if isinstance(model, ModelClass): + for column_attr in model.columns: + if column_attr.column.nullable: + self.add_literal_import("typing", "Optional") + break + + for relationship_attr in model.relationships: + if relationship_attr.type in ( + RelationshipType.ONE_TO_MANY, + RelationshipType.MANY_TO_MANY, + ): + self.add_literal_import("typing", "List") def collect_imports_for_column(self, column: Column[Any]) -> None: super().collect_imports_for_column(column) - try: - python_type = column.type.python_type - except NotImplementedError: - pass - else: - self.add_import(python_type) + if _sqla_version < (2, 0): + try: + python_type = column.type.python_type + except NotImplementedError: + pass + else: + self.add_import(python_type) def render_module_variables(self, models: list[Model]) -> str: - if not any(isinstance(model, ModelClass) for model in models): + if _sqla_version >= (2, 0): return super().render_module_variables(models) + else: + if not any(isinstance(model, ModelClass) for model in models): + return super().render_module_variables(models) - declarations: list[str] = ["mapper_registry = registry()"] - if any(not isinstance(model, ModelClass) for model in models): - declarations.append("metadata = mapper_registry.metadata") + declarations: list[str] = ["mapper_registry = registry()"] + if any(not isinstance(model, ModelClass) for model in models): + declarations.append("metadata = mapper_registry.metadata") - if not self.quote_annotations: - self.add_literal_import("__future__", "annotations") + if not self.quote_annotations: + self.add_literal_import("__future__", "annotations") - return "\n".join(declarations) + return "\n".join(declarations) def render_class_declaration(self, model: ModelClass) -> str: - superclass_part = f"({model.parent_class.name})" if model.parent_class else "" - return ( - f"@mapper_registry.mapped\n@dataclass\nclass {model.name}{superclass_part}:" - ) + if _sqla_version >= (2, 0): + return super().render_class_declaration(model) + else: + superclass_part = ( + f"({model.parent_class.name})" if model.parent_class else "" + ) + return ( + f"@mapper_registry.mapped\n@dataclass" + f"\nclass {model.name}{superclass_part}:" + ) def render_class_variables(self, model: ModelClass) -> str: - variables = [ - super().render_class_variables(model), - f"__sa_dataclass_metadata_key__ = {self.metadata_key!r}", - ] - return "\n".join(variables) + if _sqla_version >= (2, 0): + return super().render_class_variables(model) + else: + variables = [ + super().render_class_variables(model), + f"__sa_dataclass_metadata_key__ = {self.metadata_key!r}", + ] + return "\n".join(variables) def render_column_attribute(self, column_attr: ColumnAttribute) -> str: - column = column_attr.column - try: - python_type = column.type.python_type - except NotImplementedError: - python_type_name = "Any" + if _sqla_version >= (2, 0): + return super().render_column_attribute(column_attr) else: - python_type_name = python_type.__name__ - - kwargs: dict[str, Any] = {} - if column.autoincrement and column.name in column.table.primary_key: - kwargs["init"] = False - elif column.nullable: - self.add_literal_import("typing", "Optional") - kwargs["default"] = None - python_type_name = f"Optional[{python_type_name}]" - - rendered_column = self.render_column(column, column_attr.name != column.name) - kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered_column}}}" - rendered_field = render_callable("field", kwargs=kwargs) - return f"{column_attr.name}: {python_type_name} = {rendered_field}" + column = column_attr.column + try: + python_type = column.type.python_type + except NotImplementedError: + python_type_name = "Any" + else: + python_type_name = python_type.__name__ + + kwargs: dict[str, Any] = {} + if column.autoincrement and column.name in column.table.primary_key: + kwargs["init"] = False + elif column.nullable: + self.add_literal_import("typing", "Optional") + kwargs["default"] = None + python_type_name = f"Optional[{python_type_name}]" + + rendered_column = self.render_column( + column, column_attr.name != column.name + ) + kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered_column}}}" + rendered_field = render_callable("field", kwargs=kwargs) + return f"{column_attr.name}: {python_type_name} = {rendered_field}" def render_relationship(self, relationship: RelationshipAttribute) -> str: - rendered = super().render_relationship(relationship).partition(" = ")[2] - kwargs: dict[str, Any] = {} + if _sqla_version >= (2, 0): + return super().render_relationship(relationship) + else: + rendered = super().render_relationship(relationship).partition(" = ")[2] + kwargs: dict[str, Any] = {} - annotation = relationship.target.name - if self.quote_annotations: - annotation = repr(relationship.target.name) + annotation = relationship.target.name + if self.quote_annotations: + annotation = repr(relationship.target.name) - if relationship.type in ( - RelationshipType.ONE_TO_MANY, - RelationshipType.MANY_TO_MANY, - ): - self.add_literal_import("typing", "List") - annotation = f"List[{annotation}]" - kwargs["default_factory"] = "list" - else: - self.add_literal_import("typing", "Optional") - kwargs["default"] = "None" - annotation = f"Optional[{annotation}]" + if relationship.type in ( + RelationshipType.ONE_TO_MANY, + RelationshipType.MANY_TO_MANY, + ): + self.add_literal_import("typing", "List") + annotation = f"List[{annotation}]" + kwargs["default_factory"] = "list" + else: + self.add_literal_import("typing", "Optional") + kwargs["default"] = "None" + annotation = f"Optional[{annotation}]" - kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered}}}" - rendered_field = render_callable("field", kwargs=kwargs) - return f"{relationship.name}: {annotation} = {rendered_field}" + kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered}}}" + rendered_field = render_callable("field", kwargs=kwargs) + return f"{relationship.name}: {annotation} = {rendered_field}" class SQLModelGenerator(DeclarativeGenerator): @@ -1360,6 +1528,13 @@ def __init__( base_class_name=base_class_name, ) + def generate_base(self) -> None: + self.base = Base( + literal_imports=[], + declarations=[], + metadata_ref="", + ) + def collect_imports(self, models: Iterable[Model]) -> None: super(DeclarativeGenerator, self).collect_imports(models) if any(isinstance(model, ModelClass) for model in models): @@ -1397,7 +1572,8 @@ def collect_imports_for_column(self, column: Column[Any]) -> None: def render_module_variables(self, models: list[Model]) -> str: declarations: list[str] = [] if any(not isinstance(model, ModelClass) for model in models): - declarations.append(f"metadata = {self.base_class_name}.metadata") + if self.base.table_metadata_declaration is not None: + declarations.append(self.base.table_metadata_declaration) return "\n".join(declarations) @@ -1470,7 +1646,7 @@ def render_relationship_args(self, arguments: str) -> list[str]: argument_list[-1] = argument_list[-1][:-1] argument_list = [argument[1:] for argument in argument_list] - rendered_args = [] + rendered_args: list[str] = [] for arg in argument_list: if "back_populates" in arg: rendered_args.append(arg) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..c43caaab --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,56 @@ +from textwrap import dedent + +import pytest +from pytest import FixtureRequest +from sqlalchemy.engine import Engine, create_engine +from sqlalchemy.orm import clear_mappers, configure_mappers +from sqlalchemy.schema import MetaData + +from sqlacodegen.generators import _sqla_version + + +def validate_code(generated_code: str, expected_code: str) -> None: + expected_code = dedent(expected_code) + assert generated_code == expected_code + try: + exec(generated_code, {}) + configure_mappers() + finally: + clear_mappers() + + +@pytest.fixture +def engine(request: FixtureRequest) -> Engine: + dialect = getattr(request, "param", None) + if dialect == "postgresql": + return create_engine("postgresql:///testdb") + elif dialect == "mysql": + return create_engine("mysql+mysqlconnector://testdb") + else: + return create_engine("sqlite:///:memory:") + + +@pytest.fixture +def metadata() -> MetaData: + return MetaData() + + +try: + import sqlmodel # noqa: F401 + + sqlmodel_installed = True +except ImportError: + sqlmodel_installed = False + + +requires_sqlmodel = pytest.mark.skipif( + not sqlmodel_installed, reason="Test needs sqlmodel package" +) + +requires_sqlalchemy_1_4 = pytest.mark.skipif( + _sqla_version >= (2, 0), reason="Test requires SQLAlchemy 1.4.x " +) + +requires_sqlalchemy_2_0 = pytest.mark.skipif( + _sqla_version < (2, 0), reason="Test requires SQLAlchemy 2.0.x or newer" +) diff --git a/tests/test_cli.py b/tests/test_cli.py index fde265b2..26972ae4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,6 +9,8 @@ from sqlacodegen.generators import _sqla_version +from .conftest import requires_sqlalchemy_1_4 + if sys.version_info < (3, 8): from importlib_metadata import version else: @@ -79,9 +81,10 @@ def test_cli_declarative(db_path: Path, tmp_path: Path) -> None: check=True, ) - assert ( - output_path.read_text() - == f"""\ + if _sqla_version < (2, 0): + assert ( + output_path.read_text() + == f"""\ from sqlalchemy import Column, Integer, Text from {declarative_package} import declarative_base @@ -94,7 +97,25 @@ class Foo(Base): id = Column(Integer, primary_key=True) name = Column(Text, nullable=False) """ - ) + ) + else: + assert ( + output_path.read_text() + == """\ +from sqlalchemy import Integer, Text +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class Foo(Base): + __tablename__ = 'foo' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[str] = mapped_column(Text) +""" + ) def test_cli_dataclass(db_path: Path, tmp_path: Path) -> None: @@ -111,9 +132,10 @@ def test_cli_dataclass(db_path: Path, tmp_path: Path) -> None: check=True, ) - assert ( - output_path.read_text() - == f"""\ + if _sqla_version < (2, 0): + assert ( + output_path.read_text() + == f"""\ {future_imports}from dataclasses import dataclass, field from sqlalchemy import Column, Integer, Text @@ -131,9 +153,28 @@ class Foo: id: int = field(init=False, metadata={{'sa': Column(Integer, primary_key=True)}}) name: str = field(metadata={{'sa': Column(Text, nullable=False)}}) """ - ) + ) + else: + assert ( + output_path.read_text() + == """\ +from sqlalchemy import Integer, Text +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column + +class Base(MappedAsDataclass, DeclarativeBase): + pass + + +class Foo(Base): + __tablename__ = 'foo' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[str] = mapped_column(Text) +""" + ) +@requires_sqlalchemy_1_4 def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None: output_path = tmp_path / "outfile" subprocess.run( diff --git a/tests/test_generator_dataclass.py b/tests/test_generator_dataclass.py new file mode 100644 index 00000000..20cd3bdf --- /dev/null +++ b/tests/test_generator_dataclass.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import pytest +from _pytest.fixtures import FixtureRequest +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.engine import Engine +from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table +from sqlalchemy.sql.expression import text +from sqlalchemy.types import INTEGER, VARCHAR + +from sqlacodegen.generators import CodeGenerator, DataclassGenerator + +from .conftest import requires_sqlalchemy_1_4, validate_code + + +@requires_sqlalchemy_1_4 +class TestDataclassGenerator: + @pytest.fixture + def generator( + self, request: FixtureRequest, metadata: MetaData, engine: Engine + ) -> CodeGenerator: + options = getattr(request, "param", []) + return DataclassGenerator(metadata, engine, options) + + def test_basic_class(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("name", VARCHAR(20)), + ) + + validate_code( + generator.generate(), + """\ + from __future__ import annotations + + from dataclasses import dataclass, field + from typing import Optional + + from sqlalchemy import Column, Integer, String + from sqlalchemy.orm import registry + + mapper_registry = registry() + + + @mapper_registry.mapped + @dataclass + class Simple: + __tablename__ = 'simple' + __sa_dataclass_metadata_key__ = 'sa' + + id: int = field(init=False, metadata={'sa': Column(Integer, \ +primary_key=True)}) + name: Optional[str] = field(default=None, metadata={'sa': \ +Column(String(20))}) + """, + ) + + def test_mandatory_field_last(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("name", VARCHAR(20), server_default=text("foo")), + Column("age", INTEGER, nullable=False), + ) + + validate_code( + generator.generate(), + """\ + from __future__ import annotations + + from dataclasses import dataclass, field + from typing import Optional + + from sqlalchemy import Column, Integer, String, text + from sqlalchemy.orm import registry + + mapper_registry = registry() + + + @mapper_registry.mapped + @dataclass + class Simple: + __tablename__ = 'simple' + __sa_dataclass_metadata_key__ = 'sa' + + id: int = field(init=False, metadata={'sa': Column(Integer, \ +primary_key=True)}) + age: int = field(metadata={'sa': Column(Integer, nullable=False)}) + name: Optional[str] = field(default=None, metadata={'sa': \ +Column(String(20), server_default=text('foo'))}) + """, + ) + + def test_onetomany_optional(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from __future__ import annotations + + from dataclasses import dataclass, field + from typing import List, Optional + + from sqlalchemy import Column, ForeignKey, Integer + from sqlalchemy.orm import registry, relationship + + mapper_registry = registry() + + + @mapper_registry.mapped + @dataclass + class SimpleContainers: + __tablename__ = 'simple_containers' + __sa_dataclass_metadata_key__ = 'sa' + + id: int = field(init=False, metadata={'sa': Column(Integer, \ +primary_key=True)}) + + simple_items: List[SimpleItems] = field(default_factory=list, \ +metadata={'sa': relationship('SimpleItems', back_populates='container')}) + + + @mapper_registry.mapped + @dataclass + class SimpleItems: + __tablename__ = 'simple_items' + __sa_dataclass_metadata_key__ = 'sa' + + id: int = field(init=False, metadata={'sa': Column(Integer, \ +primary_key=True)}) + container_id: Optional[int] = field(default=None, \ +metadata={'sa': Column(ForeignKey('simple_containers.id'))}) + + container: Optional[SimpleContainers] = field(default=None, \ +metadata={'sa': relationship('SimpleContainers', back_populates='simple_items')}) + """, + ) + + def test_manytomany(self, generator: CodeGenerator) -> None: + Table( + "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + Table( + "container_items", + generator.metadata, + Column("item_id", INTEGER), + Column("container_id", INTEGER), + ForeignKeyConstraint(["item_id"], ["simple_items.id"]), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + + validate_code( + generator.generate(), + """\ + from __future__ import annotations + + from dataclasses import dataclass, field + from typing import List + + from sqlalchemy import Column, ForeignKey, Integer, Table + from sqlalchemy.orm import registry, relationship + + mapper_registry = registry() + metadata = mapper_registry.metadata + + + @mapper_registry.mapped + @dataclass + class SimpleContainers: + __tablename__ = 'simple_containers' + __sa_dataclass_metadata_key__ = 'sa' + + id: int = field(init=False, metadata={'sa': Column(Integer, \ +primary_key=True)}) + + item: List[SimpleItems] = field(default_factory=list, metadata=\ +{'sa': relationship('SimpleItems', secondary='container_items', \ +back_populates='container')}) + + + @mapper_registry.mapped + @dataclass + class SimpleItems: + __tablename__ = 'simple_items' + __sa_dataclass_metadata_key__ = 'sa' + + id: int = field(init=False, metadata={'sa': Column(Integer, \ +primary_key=True)}) + + container: List[SimpleContainers] = \ +field(default_factory=list, metadata={'sa': relationship('SimpleContainers', \ +secondary='container_items', back_populates='item')}) + + + t_container_items = Table( + 'container_items', metadata, + Column('item_id', ForeignKey('simple_items.id')), + Column('container_id', ForeignKey('simple_containers.id')) + ) + """, + ) + + def test_named_foreign_key_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint( + ["container_id"], ["simple_containers.id"], name="foreignkeytest" + ), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from __future__ import annotations + + from dataclasses import dataclass, field + from typing import List, Optional + + from sqlalchemy import Column, ForeignKeyConstraint, Integer + from sqlalchemy.orm import registry, relationship + + mapper_registry = registry() + + + @mapper_registry.mapped + @dataclass + class SimpleContainers: + __tablename__ = 'simple_containers' + __sa_dataclass_metadata_key__ = 'sa' + + id: int = field(init=False, metadata={'sa': Column(Integer, \ +primary_key=True)}) + + simple_items: List[SimpleItems] = field(default_factory=list, \ +metadata={'sa': relationship('SimpleItems', back_populates='container')}) + + + @mapper_registry.mapped + @dataclass + class SimpleItems: + __tablename__ = 'simple_items' + __table_args__ = ( + ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \ +name='foreignkeytest'), + ) + __sa_dataclass_metadata_key__ = 'sa' + + id: int = field(init=False, metadata={'sa': Column(Integer, \ +primary_key=True)}) + container_id: Optional[int] = field(default=None, metadata={'sa': \ +Column(Integer)}) + + container: Optional[SimpleContainers] = field(default=None, \ +metadata={'sa': relationship('SimpleContainers', back_populates='simple_items')}) + """, + ) + + def test_uuid_type_annotation(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", UUID, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from __future__ import annotations + + from dataclasses import dataclass, field + + from sqlalchemy import Column + from sqlalchemy.dialects.postgresql import UUID + from sqlalchemy.orm import registry + + mapper_registry = registry() + + + @mapper_registry.mapped + @dataclass + class Simple: + __tablename__ = 'simple' + __sa_dataclass_metadata_key__ = 'sa' + + id: str = field(init=False, metadata={'sa': \ +Column(UUID, primary_key=True)}) + """, + ) diff --git a/tests/test_generator_dataclass2.py b/tests/test_generator_dataclass2.py new file mode 100644 index 00000000..d17c1f60 --- /dev/null +++ b/tests/test_generator_dataclass2.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import pytest +from _pytest.fixtures import FixtureRequest +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.engine import Engine +from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table +from sqlalchemy.sql.expression import text +from sqlalchemy.types import INTEGER, VARCHAR + +from sqlacodegen.generators import CodeGenerator, DataclassGenerator + +from .conftest import requires_sqlalchemy_2_0, validate_code + + +@requires_sqlalchemy_2_0 +class TestDataclassGenerator: + @pytest.fixture + def generator( + self, request: FixtureRequest, metadata: MetaData, engine: Engine + ) -> CodeGenerator: + options = getattr(request, "param", []) + return DataclassGenerator(metadata, engine, options) + + def test_basic_class(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("name", VARCHAR(20)), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import Integer, String +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column + +class Base(MappedAsDataclass, DeclarativeBase): + pass + + +class Simple(Base): + __tablename__ = 'simple' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[Optional[str]] = mapped_column(String(20)) + """, + ) + + def test_mandatory_field_last(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("name", VARCHAR(20), server_default=text("foo")), + Column("age", INTEGER, nullable=False), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import Integer, String, text +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column + +class Base(MappedAsDataclass, DeclarativeBase): + pass + + +class Simple(Base): + __tablename__ = 'simple' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + age: Mapped[int] = mapped_column(Integer) + name: Mapped[Optional[str]] = mapped_column(String(20), server_default=text('foo')) + """, + ) + + def test_onetomany_optional(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column, \ +relationship + +class Base(MappedAsDataclass, DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +back_populates='container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + container_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('simple_containers.id')) + + container: Mapped['SimpleContainers'] = relationship('SimpleContainers', \ +back_populates='simple_items') + """, + ) + + def test_manytomany(self, generator: CodeGenerator) -> None: + Table( + "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + Table( + "container_items", + generator.metadata, + Column("item_id", INTEGER), + Column("container_id", INTEGER), + ForeignKeyConstraint(["item_id"], ["simple_items.id"]), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + + validate_code( + generator.generate(), + """\ +from typing import List + +from sqlalchemy import Column, ForeignKey, Integer, Table +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column, \ +relationship + +class Base(MappedAsDataclass, DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + item: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +secondary='container_items', back_populates='container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + container: Mapped[List['SimpleContainers']] = relationship('SimpleContainers', \ +secondary='container_items', back_populates='item') + + +t_container_items = Table( + 'container_items', Base.metadata, + Column('item_id', ForeignKey('simple_items.id')), + Column('container_id', ForeignKey('simple_containers.id')) +) + """, + ) + + def test_named_foreign_key_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint( + ["container_id"], ["simple_containers.id"], name="foreignkeytest" + ), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKeyConstraint, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column, \ +relationship + +class Base(MappedAsDataclass, DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +back_populates='container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \ +name='foreignkeytest'), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + container_id: Mapped[Optional[int]] = mapped_column(Integer) + + container: Mapped['SimpleContainers'] = relationship('SimpleContainers', \ +back_populates='simple_items') + """, + ) + + def test_uuid_type_annotation(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", UUID, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import UUID +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column +import uuid + +class Base(MappedAsDataclass, DeclarativeBase): + pass + + +class Simple(Base): + __tablename__ = 'simple' + + id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True) + """, + ) diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py new file mode 100644 index 00000000..79dde063 --- /dev/null +++ b/tests/test_generator_declarative.py @@ -0,0 +1,1378 @@ +from __future__ import annotations + +import pytest +from _pytest.fixtures import FixtureRequest +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy.engine import Engine +from sqlalchemy.schema import ( + CheckConstraint, + Column, + ForeignKey, + ForeignKeyConstraint, + Index, + MetaData, + Table, + UniqueConstraint, +) +from sqlalchemy.sql.expression import text +from sqlalchemy.types import INTEGER, VARCHAR, Text + +from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator + +from .conftest import requires_sqlalchemy_1_4, validate_code + + +@requires_sqlalchemy_1_4 +class TestDeclarativeGenerator: + """Test declarative mapping generation vor SQLAlchemy 1.4""" + + @pytest.fixture + def generator( + self, request: FixtureRequest, metadata: MetaData, engine: Engine + ) -> CodeGenerator: + options = getattr(request, "param", []) + return DeclarativeGenerator(metadata, engine, options) + + def test_indexes(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("number", INTEGER), + Column("text", VARCHAR), + ) + simple_items.indexes.add(Index("idx_number", simple_items.c.number)) + simple_items.indexes.add( + Index("idx_text_number", simple_items.c.text, simple_items.c.number) + ) + simple_items.indexes.add(Index("idx_text", simple_items.c.text, unique=True)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Index, Integer, String + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + + class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + Index('idx_number', 'number'), + Index('idx_text', 'text', unique=True), + Index('idx_text_number', 'text', 'number') + ) + + id = Column(Integer, primary_key=True) + number = Column(Integer) + text = Column(String) + """, + ) + + def test_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("number", INTEGER), + CheckConstraint("number > 0"), + UniqueConstraint("id", "number"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import CheckConstraint, Column, Integer, UniqueConstraint + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + + class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + CheckConstraint('number > 0'), + UniqueConstraint('id', 'number') + ) + + id = Column(Integer, primary_key=True) + number = Column(Integer) + """, + ) + + def test_onetomany(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, ForeignKey, Integer + from sqlalchemy.orm import declarative_base, relationship + + Base = declarative_base() + + + class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id = Column(Integer, primary_key=True) + + simple_items = relationship('SimpleItems', back_populates='container') + + + class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + container_id = Column(ForeignKey('simple_containers.id')) + + container = relationship('SimpleContainers', \ +back_populates='simple_items') + """, + ) + + def test_onetomany_selfref(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("parent_item_id", INTEGER), + ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, ForeignKey, Integer + from sqlalchemy.orm import declarative_base, relationship + + Base = declarative_base() + + + class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + parent_item_id = Column(ForeignKey('simple_items.id')) + + parent_item = relationship('SimpleItems', remote_side=[id], \ +back_populates='parent_item_reverse') + parent_item_reverse = relationship('SimpleItems', \ +remote_side=[parent_item_id], back_populates='parent_item') + """, + ) + + def test_onetomany_selfref_multi(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("parent_item_id", INTEGER), + Column("top_item_id", INTEGER), + ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]), + ForeignKeyConstraint(["top_item_id"], ["simple_items.id"]), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, ForeignKey, Integer + from sqlalchemy.orm import declarative_base, relationship + + Base = declarative_base() + + + class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + parent_item_id = Column(ForeignKey('simple_items.id')) + top_item_id = Column(ForeignKey('simple_items.id')) + + parent_item = relationship('SimpleItems', remote_side=[id], \ +foreign_keys=[parent_item_id], back_populates='parent_item_reverse') + parent_item_reverse = relationship('SimpleItems', \ +remote_side=[parent_item_id], foreign_keys=[parent_item_id], \ +back_populates='parent_item') + top_item = relationship('SimpleItems', remote_side=[id], \ +foreign_keys=[top_item_id], back_populates='top_item_reverse') + top_item_reverse = relationship('SimpleItems', \ +remote_side=[top_item_id], foreign_keys=[top_item_id], back_populates='top_item') + """, + ) + + def test_onetomany_composite(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id1", INTEGER), + Column("container_id2", INTEGER), + ForeignKeyConstraint( + ["container_id1", "container_id2"], + ["simple_containers.id1", "simple_containers.id2"], + ondelete="CASCADE", + onupdate="CASCADE", + ), + ) + Table( + "simple_containers", + generator.metadata, + Column("id1", INTEGER, primary_key=True), + Column("id2", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKeyConstraint, Integer +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id1 = Column(Integer, primary_key=True, nullable=False) + id2 = Column(Integer, primary_key=True, nullable=False) + + simple_items = relationship('SimpleItems', back_populates='simple_containers') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + ForeignKeyConstraint(['container_id1', 'container_id2'], \ +['simple_containers.id1', 'simple_containers.id2'], ondelete='CASCADE', \ +onupdate='CASCADE'), + ) + + id = Column(Integer, primary_key=True) + container_id1 = Column(Integer) + container_id2 = Column(Integer) + + simple_containers = relationship('SimpleContainers', back_populates='simple_items') + """, + ) + + def test_onetomany_multiref(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("parent_container_id", INTEGER), + Column("top_container_id", INTEGER), + ForeignKeyConstraint(["parent_container_id"], ["simple_containers.id"]), + ForeignKeyConstraint(["top_container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id = Column(Integer, primary_key=True) + + simple_items = relationship('SimpleItems', \ +foreign_keys='[SimpleItems.parent_container_id]', back_populates='parent_container') + simple_items_ = relationship('SimpleItems', \ +foreign_keys='[SimpleItems.top_container_id]', back_populates='top_container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + parent_container_id = Column(ForeignKey('simple_containers.id')) + top_container_id = Column(ForeignKey('simple_containers.id')) + + parent_container = relationship('SimpleContainers', \ +foreign_keys=[parent_container_id], back_populates='simple_items') + top_container = relationship('SimpleContainers', \ +foreign_keys=[top_container_id], back_populates='simple_items_') + """, + ) + + def test_onetoone(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("other_item_id", INTEGER), + ForeignKeyConstraint(["other_item_id"], ["other_items.id"]), + UniqueConstraint("other_item_id"), + ) + Table( + "other_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class OtherItems(Base): + __tablename__ = 'other_items' + + id = Column(Integer, primary_key=True) + + simple_items = relationship('SimpleItems', uselist=False, \ +back_populates='other_item') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + other_item_id = Column(ForeignKey('other_items.id'), unique=True) + + other_item = relationship('OtherItems', back_populates='simple_items') + """, + ) + + def test_onetomany_noinflect(self, generator: CodeGenerator) -> None: + Table( + "oglkrogk", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("fehwiuhfiwID", INTEGER), + ForeignKeyConstraint(["fehwiuhfiwID"], ["fehwiuhfiw.id"]), + ) + Table("fehwiuhfiw", generator.metadata, Column("id", INTEGER, primary_key=True)) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class Fehwiuhfiw(Base): + __tablename__ = 'fehwiuhfiw' + + id = Column(Integer, primary_key=True) + + oglkrogk = relationship('Oglkrogk', back_populates='fehwiuhfiw') + + +class Oglkrogk(Base): + __tablename__ = 'oglkrogk' + + id = Column(Integer, primary_key=True) + fehwiuhfiwID = Column(ForeignKey('fehwiuhfiw.id')) + + fehwiuhfiw = relationship('Fehwiuhfiw', back_populates='oglkrogk') + """, + ) + + def test_onetomany_conflicting_column(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("relationship", Text), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer, Text +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id = Column(Integer, primary_key=True) + relationship_ = Column('relationship', Text) + + simple_items = relationship('SimpleItems', back_populates='container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + container_id = Column(ForeignKey('simple_containers.id')) + + container = relationship('SimpleContainers', back_populates='simple_items') + """, + ) + + def test_onetomany_conflicting_relationship(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("relationship_id", INTEGER), + ForeignKeyConstraint(["relationship_id"], ["relationship.id"]), + ) + Table( + "relationship", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class Relationship(Base): + __tablename__ = 'relationship' + + id = Column(Integer, primary_key=True) + + simple_items = relationship('SimpleItems', back_populates='relationship_') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + relationship_id = Column(ForeignKey('relationship.id')) + + relationship_ = relationship('Relationship', back_populates='simple_items') + """, + ) + + @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True) + def test_manytoone_nobidi(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id = Column(Integer, primary_key=True) + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + container_id = Column(ForeignKey('simple_containers.id')) + + container = relationship('SimpleContainers') + """, + ) + + def test_manytomany(self, generator: CodeGenerator) -> None: + Table( + "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + Table( + "container_items", + generator.metadata, + Column("item_id", INTEGER), + Column("container_id", INTEGER), + ForeignKeyConstraint(["item_id"], ["simple_items.id"]), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer, Table +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() +metadata = Base.metadata + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id = Column(Integer, primary_key=True) + + item = relationship('SimpleItems', secondary='container_items', \ +back_populates='container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + + container = relationship('SimpleContainers', secondary='container_items', \ +back_populates='item') + + +t_container_items = Table( + 'container_items', metadata, + Column('item_id', ForeignKey('simple_items.id')), + Column('container_id', ForeignKey('simple_containers.id')) +) + """, + ) + + @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True) + def test_manytomany_nobidi(self, generator: CodeGenerator) -> None: + Table( + "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + Table( + "container_items", + generator.metadata, + Column("item_id", INTEGER), + Column("container_id", INTEGER), + ForeignKeyConstraint(["item_id"], ["simple_items.id"]), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer, Table +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() +metadata = Base.metadata + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id = Column(Integer, primary_key=True) + + item = relationship('SimpleItems', secondary='container_items') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + + +t_container_items = Table( + 'container_items', metadata, + Column('item_id', ForeignKey('simple_items.id')), + Column('container_id', ForeignKey('simple_containers.id')) +) + """, + ) + + def test_manytomany_selfref(self, generator: CodeGenerator) -> None: + Table( + "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + Table( + "child_items", + generator.metadata, + Column("parent_id", INTEGER), + Column("child_id", INTEGER), + ForeignKeyConstraint(["parent_id"], ["simple_items.id"]), + ForeignKeyConstraint(["child_id"], ["simple_items.id"]), + schema="otherschema", + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer, Table +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() +metadata = Base.metadata + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + + parent = relationship('SimpleItems', secondary='otherschema.child_items', \ +primaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, \ +secondaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, \ +back_populates='child') + child = relationship('SimpleItems', secondary='otherschema.child_items', \ +primaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, \ +secondaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, \ +back_populates='parent') + + +t_child_items = Table( + 'child_items', metadata, + Column('parent_id', ForeignKey('simple_items.id')), + Column('child_id', ForeignKey('simple_items.id')), + schema='otherschema' +) + """, + ) + + def test_manytomany_composite(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id1", INTEGER, primary_key=True), + Column("id2", INTEGER, primary_key=True), + ) + Table( + "simple_containers", + generator.metadata, + Column("id1", INTEGER, primary_key=True), + Column("id2", INTEGER, primary_key=True), + ) + Table( + "container_items", + generator.metadata, + Column("item_id1", INTEGER), + Column("item_id2", INTEGER), + Column("container_id1", INTEGER), + Column("container_id2", INTEGER), + ForeignKeyConstraint( + ["item_id1", "item_id2"], ["simple_items.id1", "simple_items.id2"] + ), + ForeignKeyConstraint( + ["container_id1", "container_id2"], + ["simple_containers.id1", "simple_containers.id2"], + ), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKeyConstraint, Integer, Table +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() +metadata = Base.metadata + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id1 = Column(Integer, primary_key=True, nullable=False) + id2 = Column(Integer, primary_key=True, nullable=False) + + simple_items = relationship('SimpleItems', secondary='container_items', \ +back_populates='simple_containers') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id1 = Column(Integer, primary_key=True, nullable=False) + id2 = Column(Integer, primary_key=True, nullable=False) + + simple_containers = relationship('SimpleContainers', secondary='container_items', \ +back_populates='simple_items') + + +t_container_items = Table( + 'container_items', metadata, + Column('item_id1', Integer), + Column('item_id2', Integer), + Column('container_id1', Integer), + Column('container_id2', Integer), + ForeignKeyConstraint(['container_id1', 'container_id2'], \ +['simple_containers.id1', 'simple_containers.id2']), + ForeignKeyConstraint(['item_id1', 'item_id2'], ['simple_items.id1', \ +'simple_items.id2']) +) + """, + ) + + def test_joined_inheritance(self, generator: CodeGenerator) -> None: + Table( + "simple_sub_items", + generator.metadata, + Column("simple_items_id", INTEGER, primary_key=True), + Column("data3", INTEGER), + ForeignKeyConstraint(["simple_items_id"], ["simple_items.super_item_id"]), + ) + Table( + "simple_super_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("data1", INTEGER), + ) + Table( + "simple_items", + generator.metadata, + Column("super_item_id", INTEGER, primary_key=True), + Column("data2", INTEGER), + ForeignKeyConstraint(["super_item_id"], ["simple_super_items.id"]), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class SimpleSuperItems(Base): + __tablename__ = 'simple_super_items' + + id = Column(Integer, primary_key=True) + data1 = Column(Integer) + + +class SimpleItems(SimpleSuperItems): + __tablename__ = 'simple_items' + + super_item_id = Column(ForeignKey('simple_super_items.id'), primary_key=True) + data2 = Column(Integer) + + +class SimpleSubItems(SimpleItems): + __tablename__ = 'simple_sub_items' + + simple_items_id = Column(ForeignKey('simple_items.super_item_id'), primary_key=True) + data3 = Column(Integer) + """, + ) + + def test_joined_inheritance_same_table_name(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + Table( + "simple", + generator.metadata, + Column("id", INTEGER, ForeignKey("simple.id"), primary_key=True), + schema="altschema", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, ForeignKey, Integer + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + + class Simple(Base): + __tablename__ = 'simple' + + id = Column(Integer, primary_key=True) + + + class Simple_(Simple): + __tablename__ = 'simple' + __table_args__ = {'schema': 'altschema'} + + id = Column(ForeignKey('simple.id'), primary_key=True) + """, + ) + + @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True) + def test_use_inflect(self, generator: CodeGenerator) -> None: + Table( + "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + + Table("singular", generator.metadata, Column("id", INTEGER, primary_key=True)) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, Integer +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class SimpleItem(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + + +class Singular(Base): + __tablename__ = 'singular' + + id = Column(Integer, primary_key=True) + """, + ) + + @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True) + @pytest.mark.parametrize( + argnames=("table_name", "class_name", "relationship_name"), + argvalues=[ + ("manufacturers", "manufacturer", "manufacturer"), + ("statuses", "status", "status"), + ("studies", "study", "study"), + ("moose", "moose", "moose"), + ], + ids=[ + "test_inflect_manufacturer", + "test_inflect_status", + "test_inflect_study", + "test_inflect_moose", + ], + ) + def test_use_inflect_plural( + self, + generator: CodeGenerator, + table_name: str, + class_name: str, + relationship_name: str, + ) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column(f"{relationship_name}_id", INTEGER), + ForeignKeyConstraint([f"{relationship_name}_id"], [f"{table_name}.id"]), + UniqueConstraint(f"{relationship_name}_id"), + ) + Table(table_name, generator.metadata, Column("id", INTEGER, primary_key=True)) + + validate_code( + generator.generate(), + f"""\ +from sqlalchemy import Column, ForeignKey, Integer +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class {class_name.capitalize()}(Base): + __tablename__ = '{table_name}' + + id = Column(Integer, primary_key=True) + + simple_item = relationship('SimpleItem', uselist=False, \ +back_populates='{relationship_name}') + + +class SimpleItem(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + {relationship_name}_id = Column(ForeignKey('{table_name}.id'), unique=True) + + {relationship_name} = relationship('{class_name.capitalize()}', \ +back_populates='simple_item') + """, + ) + + def test_table_kwargs(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + schema="testschema", + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, Integer +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = {'schema': 'testschema'} + + id = Column(Integer, primary_key=True) + """, + ) + + def test_table_args_kwargs(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("name", VARCHAR), + schema="testschema", + ) + simple_items.indexes.add( + Index("testidx", simple_items.c.id, simple_items.c.name) + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, Index, Integer, String +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + Index('testidx', 'id', 'name'), + {'schema': 'testschema'} + ) + + id = Column(Integer, primary_key=True) + name = Column(String) + """, + ) + + def test_foreign_key_schema(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("other_item_id", INTEGER), + ForeignKeyConstraint(["other_item_id"], ["otherschema.other_items.id"]), + ) + Table( + "other_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + schema="otherschema", + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, ForeignKey, Integer +from sqlalchemy.orm import declarative_base, relationship + +Base = declarative_base() + + +class OtherItems(Base): + __tablename__ = 'other_items' + __table_args__ = {'schema': 'otherschema'} + + id = Column(Integer, primary_key=True) + + simple_items = relationship('SimpleItems', back_populates='other_item') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id = Column(Integer, primary_key=True) + other_item_id = Column(ForeignKey('otherschema.other_items.id')) + + other_item = relationship('OtherItems', back_populates='simple_items') + """, + ) + + def test_invalid_attribute_names(self, generator: CodeGenerator) -> None: + Table( + "simple-items", + generator.metadata, + Column("id-test", INTEGER, primary_key=True), + Column("4test", INTEGER), + Column("_4test", INTEGER), + Column("def", INTEGER), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, Integer +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class SimpleItems(Base): + __tablename__ = 'simple-items' + + id_test = Column('id-test', Integer, primary_key=True) + _4test = Column('4test', Integer) + _4test_ = Column('_4test', Integer) + def_ = Column('def', Integer) + """, + ) + + def test_pascal(self, generator: CodeGenerator) -> None: + Table( + "CustomerAPIPreference", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, Integer +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class CustomerAPIPreference(Base): + __tablename__ = 'CustomerAPIPreference' + + id = Column(Integer, primary_key=True) + """, + ) + + def test_underscore(self, generator: CodeGenerator) -> None: + Table( + "customer_api_preference", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, Integer +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class CustomerApiPreference(Base): + __tablename__ = 'customer_api_preference' + + id = Column(Integer, primary_key=True) + """, + ) + + def test_pascal_underscore(self, generator: CodeGenerator) -> None: + Table( + "customer_API_Preference", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, Integer +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class CustomerAPIPreference(Base): + __tablename__ = 'customer_API_Preference' + + id = Column(Integer, primary_key=True) + """, + ) + + def test_pascal_multiple_underscore(self, generator: CodeGenerator) -> None: + Table( + "customer_API__Preference", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, Integer +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class CustomerAPIPreference(Base): + __tablename__ = 'customer_API__Preference' + + id = Column(Integer, primary_key=True) + """, + ) + + @pytest.mark.parametrize( + "generator, nocomments", + [([], False), (["nocomments"], True)], + indirect=["generator"], + ) + def test_column_comment(self, generator: CodeGenerator, nocomments: bool) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True, comment="this is a 'comment'"), + ) + + comment_part = "" if nocomments else ", comment=\"this is a 'comment'\"" + validate_code( + generator.generate(), + f"""\ + from sqlalchemy import Column, Integer + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + + class Simple(Base): + __tablename__ = 'simple' + + id = Column(Integer, primary_key=True{comment_part}) + """, + ) + + def test_table_comment(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + comment="this is a 'comment'", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + + class Simple(Base): + __tablename__ = 'simple' + __table_args__ = {'comment': "this is a 'comment'"} + + id = Column(Integer, primary_key=True) + """, + ) + + def test_metadata_column(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("metadata", VARCHAR), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, String + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + + class Simple(Base): + __tablename__ = 'simple' + + id = Column(Integer, primary_key=True) + metadata_ = Column('metadata', String) + """, + ) + + def test_invalid_variable_name_from_column(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column(" id ", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + + class Simple(Base): + __tablename__ = 'simple' + + id = Column(' id ', Integer, primary_key=True) + """, + ) + + def test_only_tables(self, generator: CodeGenerator) -> None: + Table("simple", generator.metadata, Column("id", INTEGER)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple = Table( + 'simple', metadata, + Column('id', Integer) + ) + """, + ) + + def test_named_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER), + Column("text", VARCHAR), + CheckConstraint("id > 0", name="checktest"), + PrimaryKeyConstraint("id", name="primarytest"), + UniqueConstraint("text", name="uniquetest"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import CheckConstraint, Column, Integer, \ +PrimaryKeyConstraint, String, UniqueConstraint + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + + class Simple(Base): + __tablename__ = 'simple' + __table_args__ = ( + CheckConstraint('id > 0', name='checktest'), + PrimaryKeyConstraint('id', name='primarytest'), + UniqueConstraint('text', name='uniquetest') + ) + + id = Column(Integer, primary_key=True) + text = Column(String) + """, + ) + + def test_named_foreign_key_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint( + ["container_id"], ["simple_containers.id"], name="foreignkeytest" + ), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, ForeignKeyConstraint, Integer + from sqlalchemy.orm import declarative_base, relationship + + Base = declarative_base() + + + class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id = Column(Integer, primary_key=True) + + simple_items = relationship('SimpleItems', back_populates='container') + + + class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \ +name='foreignkeytest'), + ) + + id = Column(Integer, primary_key=True) + container_id = Column(Integer) + + container = relationship('SimpleContainers', \ +back_populates='simple_items') + """, + ) + + # @pytest.mark.xfail(strict=True) + def test_colname_import_conflict(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("text", VARCHAR), + Column("textwithdefault", VARCHAR, server_default=text("'test'")), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, String, text + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + + class Simple(Base): + __tablename__ = 'simple' + + id = Column(Integer, primary_key=True) + text_ = Column('text', String) + textwithdefault = Column(String, server_default=text("'test'")) + """, + ) diff --git a/tests/test_generator_declarative2.py b/tests/test_generator_declarative2.py new file mode 100644 index 00000000..847a768f --- /dev/null +++ b/tests/test_generator_declarative2.py @@ -0,0 +1,1490 @@ +from __future__ import annotations + +import pytest +from _pytest.fixtures import FixtureRequest +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy.engine import Engine +from sqlalchemy.schema import ( + CheckConstraint, + Column, + ForeignKey, + ForeignKeyConstraint, + Index, + MetaData, + Table, + UniqueConstraint, +) +from sqlalchemy.sql.expression import text +from sqlalchemy.types import INTEGER, VARCHAR, Text + +from sqlacodegen.generators import CodeGenerator, DeclarativeGenerator + +from .conftest import requires_sqlalchemy_2_0, validate_code + + +@requires_sqlalchemy_2_0 +class TestDeclarativeGenerator2: + """Test declarative mapping generation vor SQLAlchemy 2.0""" + + @pytest.fixture + def generator( + self, request: FixtureRequest, metadata: MetaData, engine: Engine + ) -> CodeGenerator: + options = getattr(request, "param", []) + return DeclarativeGenerator(metadata, engine, options) + + def test_indexes(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("number", INTEGER), + Column("text", VARCHAR), + ) + simple_items.indexes.add(Index("idx_number", simple_items.c.number)) + simple_items.indexes.add( + Index("idx_text_number", simple_items.c.text, simple_items.c.number) + ) + simple_items.indexes.add(Index("idx_text", simple_items.c.text, unique=True)) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import Index, Integer, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + Index('idx_number', 'number'), + Index('idx_text', 'text', unique=True), + Index('idx_text_number', 'text', 'number') + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + number: Mapped[Optional[int]] = mapped_column(Integer) + text: Mapped[Optional[str]] = mapped_column(String) + """, + ) + + def test_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("number", INTEGER), + CheckConstraint("number > 0"), + UniqueConstraint("id", "number"), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import CheckConstraint, Integer, UniqueConstraint +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + CheckConstraint('number > 0'), + UniqueConstraint('id', 'number') + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + number: Mapped[Optional[int]] = mapped_column(Integer) + """, + ) + + def test_onetomany(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +back_populates='container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + container_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('simple_containers.id')) + + container: Mapped['SimpleContainers'] = relationship('SimpleContainers', \ +back_populates='simple_items') + """, + ) + + def test_onetomany_selfref(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("parent_item_id", INTEGER), + ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]), + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + parent_item_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('simple_items.id')) + + parent_item: Mapped['SimpleItems'] = relationship('SimpleItems', \ +remote_side=[id], back_populates='parent_item_reverse') + parent_item_reverse: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +remote_side=[parent_item_id], back_populates='parent_item') +""", + ) + + def test_onetomany_selfref_multi(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("parent_item_id", INTEGER), + Column("top_item_id", INTEGER), + ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]), + ForeignKeyConstraint(["top_item_id"], ["simple_items.id"]), + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + parent_item_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('simple_items.id')) + top_item_id: Mapped[Optional[int]] = mapped_column(ForeignKey('simple_items.id')) + + parent_item: Mapped['SimpleItems'] = relationship('SimpleItems', \ +remote_side=[id], foreign_keys=[parent_item_id], back_populates='parent_item_reverse') + parent_item_reverse: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +remote_side=[parent_item_id], foreign_keys=[parent_item_id], \ +back_populates='parent_item') + top_item: Mapped['SimpleItems'] = relationship('SimpleItems', remote_side=[id], \ +foreign_keys=[top_item_id], back_populates='top_item_reverse') + top_item_reverse: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +remote_side=[top_item_id], foreign_keys=[top_item_id], back_populates='top_item') + """, + ) + + def test_onetomany_composite(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id1", INTEGER), + Column("container_id2", INTEGER), + ForeignKeyConstraint( + ["container_id1", "container_id2"], + ["simple_containers.id1", "simple_containers.id2"], + ondelete="CASCADE", + onupdate="CASCADE", + ), + ) + Table( + "simple_containers", + generator.metadata, + Column("id1", INTEGER, primary_key=True), + Column("id2", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKeyConstraint, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id1: Mapped[int] = mapped_column(Integer, primary_key=True) + id2: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +back_populates='simple_containers') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + ForeignKeyConstraint(['container_id1', 'container_id2'], \ +['simple_containers.id1', 'simple_containers.id2'], ondelete='CASCADE', \ +onupdate='CASCADE'), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + container_id1: Mapped[Optional[int]] = mapped_column(Integer) + container_id2: Mapped[Optional[int]] = mapped_column(Integer) + + simple_containers: Mapped['SimpleContainers'] = relationship('SimpleContainers', \ +back_populates='simple_items') + """, + ) + + def test_onetomany_multiref(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("parent_container_id", INTEGER), + Column("top_container_id", INTEGER), + ForeignKeyConstraint(["parent_container_id"], ["simple_containers.id"]), + ForeignKeyConstraint(["top_container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +foreign_keys='[SimpleItems.parent_container_id]', back_populates='parent_container') + simple_items_: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +foreign_keys='[SimpleItems.top_container_id]', back_populates='top_container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + parent_container_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('simple_containers.id')) + top_container_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('simple_containers.id')) + + parent_container: Mapped['SimpleContainers'] = relationship('SimpleContainers', \ +foreign_keys=[parent_container_id], back_populates='simple_items') + top_container: Mapped['SimpleContainers'] = relationship('SimpleContainers', \ +foreign_keys=[top_container_id], back_populates='simple_items_') + """, + ) + + def test_onetoone(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("other_item_id", INTEGER), + ForeignKeyConstraint(["other_item_id"], ["other_items.id"]), + UniqueConstraint("other_item_id"), + ) + Table( + "other_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class OtherItems(Base): + __tablename__ = 'other_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped['SimpleItems'] = relationship('SimpleItems', uselist=False, \ +back_populates='other_item') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + other_item_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('other_items.id'), unique=True) + + other_item: Mapped['OtherItems'] = relationship('OtherItems', \ +back_populates='simple_items') + """, + ) + + def test_onetomany_noinflect(self, generator: CodeGenerator) -> None: + Table( + "oglkrogk", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("fehwiuhfiwID", INTEGER), + ForeignKeyConstraint(["fehwiuhfiwID"], ["fehwiuhfiw.id"]), + ) + Table("fehwiuhfiw", generator.metadata, Column("id", INTEGER, primary_key=True)) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class Fehwiuhfiw(Base): + __tablename__ = 'fehwiuhfiw' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + oglkrogk: Mapped[List['Oglkrogk']] = relationship('Oglkrogk', \ +back_populates='fehwiuhfiw') + + +class Oglkrogk(Base): + __tablename__ = 'oglkrogk' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + fehwiuhfiwID: Mapped[Optional[int]] = mapped_column(ForeignKey('fehwiuhfiw.id')) + + fehwiuhfiw: Mapped['Fehwiuhfiw'] = \ +relationship('Fehwiuhfiw', back_populates='oglkrogk') + """, + ) + + def test_onetomany_conflicting_column(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("relationship", Text), + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKey, Integer, Text +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + relationship_: Mapped[Optional[str]] = mapped_column('relationship', Text) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +back_populates='container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + container_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('simple_containers.id')) + + container: Mapped['SimpleContainers'] = relationship('SimpleContainers', \ +back_populates='simple_items') + """, + ) + + def test_onetomany_conflicting_relationship(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("relationship_id", INTEGER), + ForeignKeyConstraint(["relationship_id"], ["relationship.id"]), + ) + Table( + "relationship", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class Relationship(Base): + __tablename__ = 'relationship' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +back_populates='relationship_') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + relationship_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('relationship.id')) + + relationship_: Mapped['Relationship'] = relationship('Relationship', \ +back_populates='simple_items') + """, + ) + + @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True) + def test_manytoone_nobidi(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + container_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('simple_containers.id')) + + container: Mapped['SimpleContainers'] = relationship('SimpleContainers') +""", + ) + + def test_manytomany(self, generator: CodeGenerator) -> None: + Table("left_table", generator.metadata, Column("id", INTEGER, primary_key=True)) + Table( + "right_table", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + Table( + "association_table", + generator.metadata, + Column("left_id", INTEGER), + Column("right_id", INTEGER), + ForeignKeyConstraint(["left_id"], ["left_table.id"]), + ForeignKeyConstraint(["right_id"], ["right_table.id"]), + ) + + validate_code( + generator.generate(), + """\ +from typing import List + +from sqlalchemy import Column, ForeignKey, Integer, Table +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class LeftTable(Base): + __tablename__ = 'left_table' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + right: Mapped[List['RightTable']] = relationship('RightTable', \ +secondary='association_table', back_populates='left') + + +class RightTable(Base): + __tablename__ = 'right_table' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + left: Mapped[List['LeftTable']] = relationship('LeftTable', \ +secondary='association_table', back_populates='right') + + +t_association_table = Table( + 'association_table', Base.metadata, + Column('left_id', ForeignKey('left_table.id')), + Column('right_id', ForeignKey('right_table.id')) +) + """, + ) + + @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True) + def test_manytomany_nobidi(self, generator: CodeGenerator) -> None: + Table( + "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + Table( + "container_items", + generator.metadata, + Column("item_id", INTEGER), + Column("container_id", INTEGER), + ForeignKeyConstraint(["item_id"], ["simple_items.id"]), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + + validate_code( + generator.generate(), + """\ +from typing import List + +from sqlalchemy import Column, ForeignKey, Integer, Table +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + item: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +secondary='container_items') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + +t_container_items = Table( + 'container_items', Base.metadata, + Column('item_id', ForeignKey('simple_items.id')), + Column('container_id', ForeignKey('simple_containers.id')) +) + """, + ) + + def test_manytomany_selfref(self, generator: CodeGenerator) -> None: + Table( + "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + Table( + "child_items", + generator.metadata, + Column("parent_id", INTEGER), + Column("child_id", INTEGER), + ForeignKeyConstraint(["parent_id"], ["simple_items.id"]), + ForeignKeyConstraint(["child_id"], ["simple_items.id"]), + schema="otherschema", + ) + + validate_code( + generator.generate(), + """\ +from typing import List + +from sqlalchemy import Column, ForeignKey, Integer, Table +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + parent: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +secondary='otherschema.child_items', primaryjoin=lambda: SimpleItems.id \ +== t_child_items.c.child_id, \ +secondaryjoin=lambda: SimpleItems.id == \ +t_child_items.c.parent_id, back_populates='child') + child: Mapped[List['SimpleItems']] = \ +relationship('SimpleItems', secondary='otherschema.child_items', \ +primaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, \ +secondaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, \ +back_populates='parent') + + +t_child_items = Table( + 'child_items', Base.metadata, + Column('parent_id', ForeignKey('simple_items.id')), + Column('child_id', ForeignKey('simple_items.id')), + schema='otherschema' +) + """, + ) + + def test_manytomany_composite(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id1", INTEGER, primary_key=True), + Column("id2", INTEGER, primary_key=True), + ) + Table( + "simple_containers", + generator.metadata, + Column("id1", INTEGER, primary_key=True), + Column("id2", INTEGER, primary_key=True), + ) + Table( + "container_items", + generator.metadata, + Column("item_id1", INTEGER), + Column("item_id2", INTEGER), + Column("container_id1", INTEGER), + Column("container_id2", INTEGER), + ForeignKeyConstraint( + ["item_id1", "item_id2"], ["simple_items.id1", "simple_items.id2"] + ), + ForeignKeyConstraint( + ["container_id1", "container_id2"], + ["simple_containers.id1", "simple_containers.id2"], + ), + ) + + validate_code( + generator.generate(), + """\ +from typing import List + +from sqlalchemy import Column, ForeignKeyConstraint, Integer, Table +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id1: Mapped[int] = mapped_column(Integer, primary_key=True) + id2: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +secondary='container_items', back_populates='simple_containers') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id1: Mapped[int] = mapped_column(Integer, primary_key=True) + id2: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_containers: Mapped[List['SimpleContainers']] = \ +relationship('SimpleContainers', secondary='container_items', \ +back_populates='simple_items') + + +t_container_items = Table( + 'container_items', Base.metadata, + Column('item_id1', Integer), + Column('item_id2', Integer), + Column('container_id1', Integer), + Column('container_id2', Integer), + ForeignKeyConstraint(['container_id1', 'container_id2'], \ +['simple_containers.id1', 'simple_containers.id2']), + ForeignKeyConstraint(['item_id1', 'item_id2'], \ +['simple_items.id1', 'simple_items.id2']) +) + """, + ) + + def test_joined_inheritance(self, generator: CodeGenerator) -> None: + Table( + "simple_sub_items", + generator.metadata, + Column("simple_items_id", INTEGER, primary_key=True), + Column("data3", INTEGER), + ForeignKeyConstraint(["simple_items_id"], ["simple_items.super_item_id"]), + ) + Table( + "simple_super_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("data1", INTEGER), + ) + Table( + "simple_items", + generator.metadata, + Column("super_item_id", INTEGER, primary_key=True), + Column("data2", INTEGER), + ForeignKeyConstraint(["super_item_id"], ["simple_super_items.id"]), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleSuperItems(Base): + __tablename__ = 'simple_super_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + data1: Mapped[Optional[int]] = mapped_column(Integer) + + +class SimpleItems(SimpleSuperItems): + __tablename__ = 'simple_items' + + super_item_id: Mapped[int] = mapped_column(ForeignKey('simple_super_items.id'), \ +primary_key=True) + data2: Mapped[Optional[int]] = mapped_column(Integer) + + +class SimpleSubItems(SimpleItems): + __tablename__ = 'simple_sub_items' + + simple_items_id: Mapped[int] = \ +mapped_column(ForeignKey('simple_items.super_item_id'), primary_key=True) + data3: Mapped[Optional[int]] = mapped_column(Integer) + """, + ) + + def test_joined_inheritance_same_table_name(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + Table( + "simple", + generator.metadata, + Column("id", INTEGER, ForeignKey("simple.id"), primary_key=True), + schema="altschema", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import ForeignKey, Integer + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + class Base(DeclarativeBase): + pass + + + class Simple(Base): + __tablename__ = 'simple' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + + class Simple_(Simple): + __tablename__ = 'simple' + __table_args__ = {'schema': 'altschema'} + + id: Mapped[int] = mapped_column(ForeignKey('simple.id'), primary_key=True) + """, + ) + + @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True) + def test_use_inflect(self, generator: CodeGenerator) -> None: + Table( + "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + + Table("singular", generator.metadata, Column("id", INTEGER, primary_key=True)) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleItem(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + +class Singular(Base): + __tablename__ = 'singular' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + """, + ) + + @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True) + @pytest.mark.parametrize( + argnames=("table_name", "class_name", "relationship_name"), + argvalues=[ + ("manufacturers", "manufacturer", "manufacturer"), + ("statuses", "status", "status"), + ("studies", "study", "study"), + ("moose", "moose", "moose"), + ], + ids=[ + "test_inflect_manufacturer", + "test_inflect_status", + "test_inflect_study", + "test_inflect_moose", + ], + ) + def test_use_inflect_plural( + self, + generator: CodeGenerator, + table_name: str, + class_name: str, + relationship_name: str, + ) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column(f"{relationship_name}_id", INTEGER), + ForeignKeyConstraint([f"{relationship_name}_id"], [f"{table_name}.id"]), + UniqueConstraint(f"{relationship_name}_id"), + ) + Table(table_name, generator.metadata, Column("id", INTEGER, primary_key=True)) + + validate_code( + generator.generate(), + f"""\ +from typing import Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class {class_name.capitalize()}(Base): + __tablename__ = '{table_name}' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_item: Mapped['SimpleItem'] = relationship('SimpleItem', uselist=False, \ +back_populates='{relationship_name}') + + +class SimpleItem(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + {relationship_name}_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('{table_name}.id'), unique=True) + + {relationship_name}: Mapped['{class_name.capitalize()}'] = \ +relationship('{class_name.capitalize()}', back_populates='simple_item') + """, + ) + + def test_table_kwargs(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + schema="testschema", + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = {'schema': 'testschema'} + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + """, + ) + + def test_table_args_kwargs(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("name", VARCHAR), + schema="testschema", + ) + simple_items.indexes.add( + Index("testidx", simple_items.c.id, simple_items.c.name) + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import Index, Integer, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + Index('testidx', 'id', 'name'), + {'schema': 'testschema'} + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[Optional[str]] = mapped_column(String) + """, + ) + + def test_foreign_key_schema(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("other_item_id", INTEGER), + ForeignKeyConstraint(["other_item_id"], ["otherschema.other_items.id"]), + ) + Table( + "other_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + schema="otherschema", + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKey, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class OtherItems(Base): + __tablename__ = 'other_items' + __table_args__ = {'schema': 'otherschema'} + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +back_populates='other_item') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + other_item_id: Mapped[Optional[int]] = \ +mapped_column(ForeignKey('otherschema.other_items.id')) + + other_item: Mapped['OtherItems'] = relationship('OtherItems', \ +back_populates='simple_items') + """, + ) + + def test_invalid_attribute_names(self, generator: CodeGenerator) -> None: + Table( + "simple-items", + generator.metadata, + Column("id-test", INTEGER, primary_key=True), + Column("4test", INTEGER), + Column("_4test", INTEGER), + Column("def", INTEGER), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple-items' + + id_test: Mapped[int] = mapped_column('id-test', Integer, primary_key=True) + _4test: Mapped[Optional[int]] = mapped_column('4test', Integer) + _4test_: Mapped[Optional[int]] = mapped_column('_4test', Integer) + def_: Mapped[Optional[int]] = mapped_column('def', Integer) + """, + ) + + def test_pascal(self, generator: CodeGenerator) -> None: + Table( + "CustomerAPIPreference", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class CustomerAPIPreference(Base): + __tablename__ = 'CustomerAPIPreference' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + """, + ) + + def test_underscore(self, generator: CodeGenerator) -> None: + Table( + "customer_api_preference", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class CustomerApiPreference(Base): + __tablename__ = 'customer_api_preference' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + """, + ) + + def test_pascal_underscore(self, generator: CodeGenerator) -> None: + Table( + "customer_API_Preference", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class CustomerAPIPreference(Base): + __tablename__ = 'customer_API_Preference' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + """, + ) + + def test_pascal_multiple_underscore(self, generator: CodeGenerator) -> None: + Table( + "customer_API__Preference", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class CustomerAPIPreference(Base): + __tablename__ = 'customer_API__Preference' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + """, + ) + + @pytest.mark.parametrize( + "generator, nocomments", + [([], False), (["nocomments"], True)], + indirect=["generator"], + ) + def test_column_comment(self, generator: CodeGenerator, nocomments: bool) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True, comment="this is a 'comment'"), + ) + + comment_part = "" if nocomments else ", comment=\"this is a 'comment'\"" + validate_code( + generator.generate(), + f"""\ +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class Simple(Base): + __tablename__ = 'simple' + + id: Mapped[int] = mapped_column(Integer, primary_key=True{comment_part}) +""", + ) + + def test_table_comment(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + comment="this is a 'comment'", + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class Simple(Base): + __tablename__ = 'simple' + __table_args__ = {'comment': "this is a 'comment'"} + + id: Mapped[int] = mapped_column(Integer, primary_key=True) +""", + ) + + def test_metadata_column(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("metadata", VARCHAR), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import Integer, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class Simple(Base): + __tablename__ = 'simple' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + metadata_: Mapped[Optional[str]] = mapped_column('metadata', String) +""", + ) + + def test_invalid_variable_name_from_column(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column(" id ", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class Simple(Base): + __tablename__ = 'simple' + + id: Mapped[int] = mapped_column(' id ', Integer, primary_key=True) +""", + ) + + def test_only_tables(self, generator: CodeGenerator) -> None: + Table("simple", generator.metadata, Column("id", INTEGER)) + + validate_code( + generator.generate(), + """\ +from sqlalchemy import Column, Integer, MetaData, Table + +metadata = MetaData() + + +t_simple = Table( + 'simple', metadata, + Column('id', Integer) +) + """, + ) + + def test_named_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER), + Column("text", VARCHAR), + CheckConstraint("id > 0", name="checktest"), + PrimaryKeyConstraint("id", name="primarytest"), + UniqueConstraint("text", name="uniquetest"), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import CheckConstraint, Integer, PrimaryKeyConstraint, \ +String, UniqueConstraint +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class Simple(Base): + __tablename__ = 'simple' + __table_args__ = ( + CheckConstraint('id > 0', name='checktest'), + PrimaryKeyConstraint('id', name='primarytest'), + UniqueConstraint('text', name='uniquetest') + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + text: Mapped[Optional[str]] = mapped_column(String) +""", + ) + + def test_named_foreign_key_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint( + ["container_id"], ["simple_containers.id"], name="foreignkeytest" + ), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ +from typing import List, Optional + +from sqlalchemy import ForeignKeyConstraint, Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + +class Base(DeclarativeBase): + pass + + +class SimpleContainers(Base): + __tablename__ = 'simple_containers' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + + simple_items: Mapped[List['SimpleItems']] = relationship('SimpleItems', \ +back_populates='container') + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = ( + ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \ +name='foreignkeytest'), + ) + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + container_id: Mapped[Optional[int]] = mapped_column(Integer) + + container: Mapped['SimpleContainers'] = relationship('SimpleContainers', \ +back_populates='simple_items') +""", + ) + + # @pytest.mark.xfail(strict=True) + def test_colname_import_conflict(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("text", VARCHAR), + Column("textwithdefault", VARCHAR, server_default=text("'test'")), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import Integer, String, text +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class Simple(Base): + __tablename__ = 'simple' + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + text_: Mapped[Optional[str]] = mapped_column('text', String) + textwithdefault: Mapped[Optional[str]] = mapped_column(String, \ +server_default=text("'test'")) +""", + ) diff --git a/tests/test_generator_sqlmodel.py b/tests/test_generator_sqlmodel.py new file mode 100644 index 00000000..538400a7 --- /dev/null +++ b/tests/test_generator_sqlmodel.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import pytest +from _pytest.fixtures import FixtureRequest +from sqlalchemy.engine import Engine +from sqlalchemy.schema import ( + CheckConstraint, + Column, + ForeignKeyConstraint, + Index, + MetaData, + Table, + UniqueConstraint, +) +from sqlalchemy.types import INTEGER, VARCHAR + +from sqlacodegen.generators import CodeGenerator, SQLModelGenerator + +from .conftest import requires_sqlmodel, validate_code + + +@requires_sqlmodel +class TestSQLModelGenerator: + @pytest.fixture + def generator( + self, request: FixtureRequest, metadata: MetaData, engine: Engine + ) -> CodeGenerator: + options = getattr(request, "param", []) + return SQLModelGenerator(metadata, engine, options) + + def test_indexes(self, generator: CodeGenerator) -> None: + simple_items = Table( + "item", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("number", INTEGER), + Column("text", VARCHAR), + ) + simple_items.indexes.add(Index("idx_number", simple_items.c.number)) + simple_items.indexes.add( + Index("idx_text_number", simple_items.c.text, simple_items.c.number) + ) + simple_items.indexes.add(Index("idx_text", simple_items.c.text, unique=True)) + + validate_code( + generator.generate(), + """\ + from typing import Optional + + from sqlalchemy import Column, Index, Integer, String + from sqlmodel import Field, SQLModel + + class Item(SQLModel, table=True): + __table_args__ = ( + Index('idx_number', 'number'), + Index('idx_text', 'text', unique=True), + Index('idx_text_number', 'text', 'number') + ) + + id: Optional[int] = Field(default=None, sa_column=Column(\ +'id', Integer, primary_key=True)) + number: Optional[int] = Field(default=None, sa_column=Column(\ +'number', Integer)) + text: Optional[str] = Field(default=None, sa_column=Column(\ +'text', String)) + """, + ) + + def test_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple_constraints", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("number", INTEGER), + CheckConstraint("number > 0"), + UniqueConstraint("id", "number"), + ) + + validate_code( + generator.generate(), + """\ + from typing import Optional + + from sqlalchemy import CheckConstraint, Column, Integer, UniqueConstraint + from sqlmodel import Field, SQLModel + + class SimpleConstraints(SQLModel, table=True): + __tablename__ = 'simple_constraints' + __table_args__ = ( + CheckConstraint('number > 0'), + UniqueConstraint('id', 'number') + ) + + id: Optional[int] = Field(default=None, sa_column=Column(\ +'id', Integer, primary_key=True)) + number: Optional[int] = Field(default=None, sa_column=Column(\ +'number', Integer)) + """, + ) + + def test_onetomany(self, generator: CodeGenerator) -> None: + Table( + "simple_goods", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("container_id", INTEGER), + ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), + ) + Table( + "simple_containers", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from typing import List, Optional + + from sqlalchemy import Column, ForeignKey, Integer + from sqlmodel import Field, Relationship, SQLModel + + class SimpleContainers(SQLModel, table=True): + __tablename__ = 'simple_containers' + + id: Optional[int] = Field(default=None, sa_column=Column(\ +'id', Integer, primary_key=True)) + + simple_goods: List['SimpleGoods'] = Relationship(\ +back_populates='container') + + + class SimpleGoods(SQLModel, table=True): + __tablename__ = 'simple_goods' + + id: Optional[int] = Field(default=None, sa_column=Column(\ +'id', Integer, primary_key=True)) + container_id: Optional[int] = Field(default=None, sa_column=Column(\ +'container_id', ForeignKey('simple_containers.id'))) + + container: Optional['SimpleContainers'] = Relationship(\ +back_populates='simple_goods') + """, + ) + + def test_onetoone(self, generator: CodeGenerator) -> None: + Table( + "simple_onetoone", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("other_item_id", INTEGER), + ForeignKeyConstraint(["other_item_id"], ["other_items.id"]), + UniqueConstraint("other_item_id"), + ) + Table( + "other_items", generator.metadata, Column("id", INTEGER, primary_key=True) + ) + + validate_code( + generator.generate(), + """\ + from typing import Optional + + from sqlalchemy import Column, ForeignKey, Integer + from sqlmodel import Field, Relationship, SQLModel + + class OtherItems(SQLModel, table=True): + __tablename__ = 'other_items' + + id: Optional[int] = Field(default=None, sa_column=Column(\ +'id', Integer, primary_key=True)) + + simple_onetoone: Optional['SimpleOnetoone'] = Relationship(\ +sa_relationship_kwargs={'uselist': False}, back_populates='other_item') + + + class SimpleOnetoone(SQLModel, table=True): + __tablename__ = 'simple_onetoone' + + id: Optional[int] = Field(default=None, sa_column=Column(\ +'id', Integer, primary_key=True)) + other_item_id: Optional[int] = Field(default=None, sa_column=Column(\ +'other_item_id', ForeignKey('other_items.id'), unique=True)) + + other_item: Optional['OtherItems'] = Relationship(\ +back_populates='simple_onetoone') + """, + ) diff --git a/tests/test_generator_tables.py b/tests/test_generator_tables.py new file mode 100644 index 00000000..38a5da57 --- /dev/null +++ b/tests/test_generator_tables.py @@ -0,0 +1,951 @@ +from __future__ import annotations + +from textwrap import dedent + +import pytest +from _pytest.fixtures import FixtureRequest +from sqlalchemy.dialects import mysql, postgresql +from sqlalchemy.engine import Engine +from sqlalchemy.schema import ( + CheckConstraint, + Column, + Computed, + ForeignKey, + Identity, + Index, + MetaData, + Table, + UniqueConstraint, +) +from sqlalchemy.sql.expression import text +from sqlalchemy.sql.sqltypes import NullType +from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text + +from sqlacodegen.generators import CodeGenerator, TablesGenerator + +from .conftest import requires_sqlalchemy_1_4, validate_code + + +@requires_sqlalchemy_1_4 +class TestTablesGenerator: + @pytest.fixture + def generator( + self, request: FixtureRequest, metadata: MetaData, engine: Engine + ) -> CodeGenerator: + options = getattr(request, "param", []) + return TablesGenerator(metadata, engine, options) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_fancy_coltypes(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("enum", postgresql.ENUM("A", "B", name="blah")), + Column("bool", postgresql.BOOLEAN), + Column("number", NUMERIC(10, asdecimal=False)), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Boolean, Column, Enum, MetaData, Numeric, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('enum', Enum('A', 'B', name='blah')), + Column('bool', Boolean), + Column('number', Numeric(10, asdecimal=False)) + ) + """, + ) + + def test_boolean_detection(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("bool1", INTEGER), + Column("bool2", SMALLINT), + Column("bool3", mysql.TINYINT), + CheckConstraint("simple_items.bool1 IN (0, 1)"), + CheckConstraint("simple_items.bool2 IN (0, 1)"), + CheckConstraint("simple_items.bool3 IN (0, 1)"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Boolean, Column, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('bool1', Boolean), + Column('bool2', Boolean), + Column('bool3', Boolean) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_arrays(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "dp_array", postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)) + ), + Column("int_array", postgresql.ARRAY(INTEGER)), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import ARRAY, Column, Float, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('dp_array', ARRAY(Float(precision=53))), + Column('int_array', ARRAY(Integer())) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_jsonb(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("jsonb", postgresql.JSONB(astext_type=Text(50))), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, Table, Text + from sqlalchemy.dialects.postgresql import JSONB + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('jsonb', JSONB(astext_type=Text(length=50))) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_jsonb_default(self, generator: CodeGenerator) -> None: + Table("simple_items", generator.metadata, Column("jsonb", postgresql.JSONB)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, Table + from sqlalchemy.dialects.postgresql import JSONB + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('jsonb', JSONB) + ) + """, + ) + + def test_enum_detection(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("enum", VARCHAR(255)), + CheckConstraint(r"simple_items.enum IN ('A', '\'B', 'C')"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Enum, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('enum', Enum('A', "\\\\'B", 'C')) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_column_adaptation(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", postgresql.BIGINT), + Column("length", postgresql.DOUBLE_PRECISION), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import BigInteger, Column, Float, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', BigInteger), + Column('length', Float) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_column_types(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", mysql.INTEGER), + Column("name", mysql.VARCHAR(255)), + Column("set", mysql.SET("one", "two")), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, String, Table + from sqlalchemy.dialects.mysql import SET + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer), + Column('name', String(255)), + Column('set', SET('one', 'two')) + ) + """, + ) + + def test_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER), + Column("number", INTEGER), + CheckConstraint("number > 0"), + UniqueConstraint("id", "number"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table, \ +UniqueConstraint + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer), + Column('number', Integer), + CheckConstraint('number > 0'), + UniqueConstraint('id', 'number') + ) + """, + ) + + def test_indexes(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("id", INTEGER), + Column("number", INTEGER), + Column("text", VARCHAR), + Index("ix_empty"), + ) + simple_items.indexes.add(Index("ix_number", simple_items.c.number)) + simple_items.indexes.add( + Index( + "ix_text_number", + simple_items.c.text, + simple_items.c.number, + unique=True, + ) + ) + simple_items.indexes.add(Index("ix_text", simple_items.c.text, unique=True)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Index, Integer, MetaData, String, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer), + Column('number', Integer, index=True), + Column('text', String, unique=True, index=True), + Index('ix_empty'), + Index('ix_text_number', 'text', 'number', unique=True) + ) + """, + ) + + def test_table_comment(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + comment="this is a 'comment'", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple = Table( + 'simple', metadata, + Column('id', Integer, primary_key=True), + comment="this is a 'comment'" + ) + """, + ) + + def test_table_name_identifiers(self, generator: CodeGenerator) -> None: + Table( + "simple-items table", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items_table = Table( + 'simple-items table', metadata, + Column('id', Integer, primary_key=True) + ) + """, + ) + + @pytest.mark.parametrize("generator", [["noindexes"]], indirect=True) + def test_option_noindexes(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("number", INTEGER), + CheckConstraint("number > 2"), + ) + simple_items.indexes.add(Index("idx_number", simple_items.c.number)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('number', Integer), + CheckConstraint('number > 2') + ) + """, + ) + + @pytest.mark.parametrize("generator", [["noconstraints"]], indirect=True) + def test_option_noconstraints(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("number", INTEGER), + CheckConstraint("number > 2"), + ) + simple_items.indexes.add(Index("ix_number", simple_items.c.number)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('number', Integer, index=True) + ) + """, + ) + + @pytest.mark.parametrize("generator", [["nocomments"]], indirect=True) + def test_option_nocomments(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True, comment="pk column comment"), + comment="this is a 'comment'", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple = Table( + 'simple', metadata, + Column('id', Integer, primary_key=True) + ) + """, + ) + + @pytest.mark.parametrize( + "persisted, extra_args", + [(None, ""), (False, ", persisted=False"), (True, ", persisted=True")], + ) + def test_computed_column( + self, generator: CodeGenerator, persisted: bool | None, extra_args: str + ) -> None: + Table( + "computed", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("computed", INTEGER, Computed("1 + 2", persisted=persisted)), + ) + + validate_code( + generator.generate(), + f"""\ + from sqlalchemy import Column, Computed, Integer, MetaData, Table + + metadata = MetaData() + + + t_computed = Table( + 'computed', metadata, + Column('id', Integer, primary_key=True), + Column('computed', Integer, Computed('1 + 2'{extra_args})) + ) + """, + ) + + def test_schema(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("name", VARCHAR), + schema="testschema", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, String, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('name', String), + schema='testschema' + ) + """, + ) + + def test_foreign_key_options(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "name", + VARCHAR, + ForeignKey( + "simple_items.name", + ondelete="CASCADE", + onupdate="CASCADE", + deferrable=True, + initially="DEFERRED", + ), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, ForeignKey, MetaData, String, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('name', String, ForeignKey('simple_items.name', \ +ondelete='CASCADE', onupdate='CASCADE', deferrable=True, initially='DEFERRED')) + ) + """, + ) + + def test_pk_default(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text("uuid_generate_v4()"), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table, text + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True, \ +server_default=text('uuid_generate_v4()')) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_timestamp(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("timestamp", mysql.TIMESTAMP), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, TIMESTAMP, Table + + metadata = MetaData() + + + t_simple = Table( + 'simple', metadata, + Column('id', Integer, primary_key=True), + Column('timestamp', TIMESTAMP) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_integer_display_width(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("number", mysql.INTEGER(11)), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + from sqlalchemy.dialects.mysql import INTEGER + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True), + Column('number', INTEGER(11)) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_tinytext(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("my_tinytext", mysql.TINYTEXT), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + from sqlalchemy.dialects.mysql import TINYTEXT + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True), + Column('my_tinytext', TINYTEXT) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_mediumtext(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("my_mediumtext", mysql.MEDIUMTEXT), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + from sqlalchemy.dialects.mysql import MEDIUMTEXT + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True), + Column('my_mediumtext', MEDIUMTEXT) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_longtext(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("my_longtext", mysql.LONGTEXT), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + from sqlalchemy.dialects.mysql import LONGTEXT + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True), + Column('my_longtext', LONGTEXT) + ) + """, + ) + + def test_schema_boolean(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("bool1", INTEGER), + CheckConstraint("testschema.simple_items.bool1 IN (0, 1)"), + schema="testschema", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Boolean, Column, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('bool1', Boolean), + schema='testschema' + ) + """, + ) + + def test_server_default_multiline(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text( + dedent( + """\ + /*Comment*/ + /*Next line*/ + something()""" + ) + ), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table, text + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True, server_default=\ +text('/*Comment*/\\n/*Next line*/\\nsomething()')) + ) + """, + ) + + def test_server_default_colon(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("problem", VARCHAR, server_default=text("':001'")), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, String, Table, text + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('problem', String, server_default=text("':001'")) + ) + """, + ) + + def test_null_type(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("problem", NullType), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, Table + from sqlalchemy.sql.sqltypes import NullType + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('problem', NullType) + ) + """, + ) + + def test_identity_column(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=Identity(start=1, increment=2), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Identity, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, Identity(start=1, increment=2), primary_key=True) + ) + """, + ) + + def test_multiline_column_comment(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, comment="This\nis a multi-line\ncomment"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, comment='This\\nis a multi-line\\ncomment') + ) + """, + ) + + def test_multiline_table_comment(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER), + comment="This\nis a multi-line\ncomment", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer), + comment='This\\nis a multi-line\\ncomment' + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_postgresql_sequence_standard_name(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text("nextval('simple_items_id_seq'::regclass)"), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_postgresql_sequence_nonstandard_name( + self, generator: CodeGenerator + ) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text("nextval('test_seq'::regclass)"), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Sequence, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, Sequence('test_seq'), primary_key=True) + ) + """, + ) + + @pytest.mark.parametrize( + "schemaname, seqname", + [ + pytest.param("myschema", "test_seq"), + pytest.param("myschema", '"test_seq"'), + pytest.param('"my.schema"', "test_seq"), + pytest.param('"my.schema"', '"test_seq"'), + ], + ) + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_postgresql_sequence_with_schema( + self, generator: CodeGenerator, schemaname: str, seqname: str + ) -> None: + expected_schema = schemaname.strip('"') + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text(f"nextval('{schemaname}.{seqname}'::regclass)"), + ), + schema=expected_schema, + ) + + validate_code( + generator.generate(), + f"""\ + from sqlalchemy import Column, Integer, MetaData, Sequence, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, Sequence('test_seq', \ +schema='{expected_schema}'), primary_key=True), + schema='{expected_schema}' + ) + """, + ) diff --git a/tests/test_generator_tables2.py b/tests/test_generator_tables2.py new file mode 100644 index 00000000..796b4e82 --- /dev/null +++ b/tests/test_generator_tables2.py @@ -0,0 +1,951 @@ +from __future__ import annotations + +from textwrap import dedent + +import pytest +from _pytest.fixtures import FixtureRequest +from sqlalchemy.dialects import mysql, postgresql +from sqlalchemy.engine import Engine +from sqlalchemy.schema import ( + CheckConstraint, + Column, + Computed, + ForeignKey, + Identity, + Index, + MetaData, + Table, + UniqueConstraint, +) +from sqlalchemy.sql.expression import text +from sqlalchemy.sql.sqltypes import NullType +from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text + +from sqlacodegen.generators import CodeGenerator, TablesGenerator + +from .conftest import requires_sqlalchemy_2_0, validate_code + + +@requires_sqlalchemy_2_0 +class TestTablesGenerator: + @pytest.fixture + def generator( + self, request: FixtureRequest, metadata: MetaData, engine: Engine + ) -> CodeGenerator: + options = getattr(request, "param", []) + return TablesGenerator(metadata, engine, options) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_fancy_coltypes(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("enum", postgresql.ENUM("A", "B", name="blah")), + Column("bool", postgresql.BOOLEAN), + Column("number", NUMERIC(10, asdecimal=False)), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Boolean, Column, Enum, MetaData, Numeric, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('enum', Enum('A', 'B', name='blah')), + Column('bool', Boolean), + Column('number', Numeric(10, asdecimal=False)) + ) + """, + ) + + def test_boolean_detection(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("bool1", INTEGER), + Column("bool2", SMALLINT), + Column("bool3", mysql.TINYINT), + CheckConstraint("simple_items.bool1 IN (0, 1)"), + CheckConstraint("simple_items.bool2 IN (0, 1)"), + CheckConstraint("simple_items.bool3 IN (0, 1)"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Boolean, Column, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('bool1', Boolean), + Column('bool2', Boolean), + Column('bool3', Boolean) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_arrays(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "dp_array", postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)) + ), + Column("int_array", postgresql.ARRAY(INTEGER)), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import ARRAY, Column, Double, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('dp_array', ARRAY(Double(precision=53))), + Column('int_array', ARRAY(Integer())) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_jsonb(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("jsonb", postgresql.JSONB(astext_type=Text(50))), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, Table, Text + from sqlalchemy.dialects.postgresql import JSONB + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('jsonb', JSONB(astext_type=Text(length=50))) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_jsonb_default(self, generator: CodeGenerator) -> None: + Table("simple_items", generator.metadata, Column("jsonb", postgresql.JSONB)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, Table + from sqlalchemy.dialects.postgresql import JSONB + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('jsonb', JSONB) + ) + """, + ) + + def test_enum_detection(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("enum", VARCHAR(255)), + CheckConstraint(r"simple_items.enum IN ('A', '\'B', 'C')"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Enum, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('enum', Enum('A', "\\\\'B", 'C')) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_column_adaptation(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", postgresql.BIGINT), + Column("length", postgresql.DOUBLE_PRECISION), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import BigInteger, Column, Double, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', BigInteger), + Column('length', Double) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_column_types(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", mysql.INTEGER), + Column("name", mysql.VARCHAR(255)), + Column("set", mysql.SET("one", "two")), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, String, Table + from sqlalchemy.dialects.mysql import SET + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer), + Column('name', String(255)), + Column('set', SET('one', 'two')) + ) + """, + ) + + def test_constraints(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER), + Column("number", INTEGER), + CheckConstraint("number > 0"), + UniqueConstraint("id", "number"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table, \ +UniqueConstraint + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer), + Column('number', Integer), + CheckConstraint('number > 0'), + UniqueConstraint('id', 'number') + ) + """, + ) + + def test_indexes(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("id", INTEGER), + Column("number", INTEGER), + Column("text", VARCHAR), + Index("ix_empty"), + ) + simple_items.indexes.add(Index("ix_number", simple_items.c.number)) + simple_items.indexes.add( + Index( + "ix_text_number", + simple_items.c.text, + simple_items.c.number, + unique=True, + ) + ) + simple_items.indexes.add(Index("ix_text", simple_items.c.text, unique=True)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Index, Integer, MetaData, String, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer), + Column('number', Integer, index=True), + Column('text', String, unique=True, index=True), + Index('ix_empty'), + Index('ix_text_number', 'text', 'number', unique=True) + ) + """, + ) + + def test_table_comment(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + comment="this is a 'comment'", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple = Table( + 'simple', metadata, + Column('id', Integer, primary_key=True), + comment="this is a 'comment'" + ) + """, + ) + + def test_table_name_identifiers(self, generator: CodeGenerator) -> None: + Table( + "simple-items table", + generator.metadata, + Column("id", INTEGER, primary_key=True), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items_table = Table( + 'simple-items table', metadata, + Column('id', Integer, primary_key=True) + ) + """, + ) + + @pytest.mark.parametrize("generator", [["noindexes"]], indirect=True) + def test_option_noindexes(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("number", INTEGER), + CheckConstraint("number > 2"), + ) + simple_items.indexes.add(Index("idx_number", simple_items.c.number)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('number', Integer), + CheckConstraint('number > 2') + ) + """, + ) + + @pytest.mark.parametrize("generator", [["noconstraints"]], indirect=True) + def test_option_noconstraints(self, generator: CodeGenerator) -> None: + simple_items = Table( + "simple_items", + generator.metadata, + Column("number", INTEGER), + CheckConstraint("number > 2"), + ) + simple_items.indexes.add(Index("ix_number", simple_items.c.number)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('number', Integer, index=True) + ) + """, + ) + + @pytest.mark.parametrize("generator", [["nocomments"]], indirect=True) + def test_option_nocomments(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True, comment="pk column comment"), + comment="this is a 'comment'", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple = Table( + 'simple', metadata, + Column('id', Integer, primary_key=True) + ) + """, + ) + + @pytest.mark.parametrize( + "persisted, extra_args", + [(None, ""), (False, ", persisted=False"), (True, ", persisted=True")], + ) + def test_computed_column( + self, generator: CodeGenerator, persisted: bool | None, extra_args: str + ) -> None: + Table( + "computed", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("computed", INTEGER, Computed("1 + 2", persisted=persisted)), + ) + + validate_code( + generator.generate(), + f"""\ + from sqlalchemy import Column, Computed, Integer, MetaData, Table + + metadata = MetaData() + + + t_computed = Table( + 'computed', metadata, + Column('id', Integer, primary_key=True), + Column('computed', Integer, Computed('1 + 2'{extra_args})) + ) + """, + ) + + def test_schema(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("name", VARCHAR), + schema="testschema", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, String, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('name', String), + schema='testschema' + ) + """, + ) + + def test_foreign_key_options(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "name", + VARCHAR, + ForeignKey( + "simple_items.name", + ondelete="CASCADE", + onupdate="CASCADE", + deferrable=True, + initially="DEFERRED", + ), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, ForeignKey, MetaData, String, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('name', String, ForeignKey('simple_items.name', \ +ondelete='CASCADE', onupdate='CASCADE', deferrable=True, initially='DEFERRED')) + ) + """, + ) + + def test_pk_default(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text("uuid_generate_v4()"), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table, text + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True, \ +server_default=text('uuid_generate_v4()')) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_timestamp(self, generator: CodeGenerator) -> None: + Table( + "simple", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("timestamp", mysql.TIMESTAMP), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, TIMESTAMP, Table + + metadata = MetaData() + + + t_simple = Table( + 'simple', metadata, + Column('id', Integer, primary_key=True), + Column('timestamp', TIMESTAMP) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_integer_display_width(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("number", mysql.INTEGER(11)), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + from sqlalchemy.dialects.mysql import INTEGER + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True), + Column('number', INTEGER(11)) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_tinytext(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("my_tinytext", mysql.TINYTEXT), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + from sqlalchemy.dialects.mysql import TINYTEXT + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True), + Column('my_tinytext', TINYTEXT) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_mediumtext(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("my_mediumtext", mysql.MEDIUMTEXT), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + from sqlalchemy.dialects.mysql import MEDIUMTEXT + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True), + Column('my_mediumtext', MEDIUMTEXT) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) + def test_mysql_longtext(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, primary_key=True), + Column("my_longtext", mysql.LONGTEXT), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + from sqlalchemy.dialects.mysql import LONGTEXT + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True), + Column('my_longtext', LONGTEXT) + ) + """, + ) + + def test_schema_boolean(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("bool1", INTEGER), + CheckConstraint("testschema.simple_items.bool1 IN (0, 1)"), + schema="testschema", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Boolean, Column, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('bool1', Boolean), + schema='testschema' + ) + """, + ) + + def test_server_default_multiline(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text( + dedent( + """\ + /*Comment*/ + /*Next line*/ + something()""" + ) + ), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table, text + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True, server_default=\ +text('/*Comment*/\\n/*Next line*/\\nsomething()')) + ) + """, + ) + + def test_server_default_colon(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("problem", VARCHAR, server_default=text("':001'")), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, String, Table, text + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('problem', String, server_default=text("':001'")) + ) + """, + ) + + def test_null_type(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("problem", NullType), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, Table + from sqlalchemy.sql.sqltypes import NullType + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('problem', NullType) + ) + """, + ) + + def test_identity_column(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=Identity(start=1, increment=2), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Identity, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, Identity(start=1, increment=2), primary_key=True) + ) + """, + ) + + def test_multiline_column_comment(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER, comment="This\nis a multi-line\ncomment"), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, comment='This\\nis a multi-line\\ncomment') + ) + """, + ) + + def test_multiline_table_comment(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id", INTEGER), + comment="This\nis a multi-line\ncomment", + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer), + comment='This\\nis a multi-line\\ncomment' + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_postgresql_sequence_standard_name(self, generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text("nextval('simple_items_id_seq'::regclass)"), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, primary_key=True) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_postgresql_sequence_nonstandard_name( + self, generator: CodeGenerator + ) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text("nextval('test_seq'::regclass)"), + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, Integer, MetaData, Sequence, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, Sequence('test_seq'), primary_key=True) + ) + """, + ) + + @pytest.mark.parametrize( + "schemaname, seqname", + [ + pytest.param("myschema", "test_seq"), + pytest.param("myschema", '"test_seq"'), + pytest.param('"my.schema"', "test_seq"), + pytest.param('"my.schema"', '"test_seq"'), + ], + ) + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) + def test_postgresql_sequence_with_schema( + self, generator: CodeGenerator, schemaname: str, seqname: str + ) -> None: + expected_schema = schemaname.strip('"') + Table( + "simple_items", + generator.metadata, + Column( + "id", + INTEGER, + primary_key=True, + server_default=text(f"nextval('{schemaname}.{seqname}'::regclass)"), + ), + schema=expected_schema, + ) + + validate_code( + generator.generate(), + f"""\ + from sqlalchemy import Column, Integer, MetaData, Sequence, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('id', Integer, Sequence('test_seq', \ +schema='{expected_schema}'), primary_key=True), + schema='{expected_schema}' + ) + """, + ) diff --git a/tests/test_generators.py b/tests/test_generators.py deleted file mode 100644 index edc78dba..00000000 --- a/tests/test_generators.py +++ /dev/null @@ -1,2805 +0,0 @@ -from textwrap import dedent - -import pytest -from pytest import FixtureRequest -from sqlalchemy import PrimaryKeyConstraint -from sqlalchemy.dialects import mysql, postgresql -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.engine import Engine, create_engine -from sqlalchemy.orm import clear_mappers, configure_mappers -from sqlalchemy.schema import ( - CheckConstraint, - Column, - Computed, - ForeignKey, - ForeignKeyConstraint, - Identity, - Index, - MetaData, - Table, - UniqueConstraint, -) -from sqlalchemy.sql.expression import text -from sqlalchemy.sql.sqltypes import NullType -from sqlalchemy.types import INTEGER, NUMERIC, SMALLINT, VARCHAR, Text - -from sqlacodegen.generators import ( - CodeGenerator, - DataclassGenerator, - DeclarativeGenerator, - SQLModelGenerator, - TablesGenerator, -) - - -def validate_code(generated_code: str, expected_code: str) -> None: - expected_code = dedent(expected_code) - assert generated_code == expected_code - try: - exec(generated_code, {}) - configure_mappers() - finally: - clear_mappers() - - -@pytest.fixture -def engine(request: FixtureRequest) -> Engine: - dialect = getattr(request, "param", None) - if dialect == "postgresql": - return create_engine("postgresql:///testdb") - elif dialect == "mysql": - return create_engine("mysql+mysqlconnector://testdb") - else: - return create_engine("sqlite:///:memory:") - - -@pytest.fixture -def metadata() -> MetaData: - return MetaData() - - -class TestTablesGenerator: - @pytest.fixture - def generator( - self, request: FixtureRequest, metadata: MetaData, engine: Engine - ) -> CodeGenerator: - options = getattr(request, "param", []) - return TablesGenerator(metadata, engine, options) - - @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) - def test_fancy_coltypes(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("enum", postgresql.ENUM("A", "B", name="blah")), - Column("bool", postgresql.BOOLEAN), - Column("number", NUMERIC(10, asdecimal=False)), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Boolean, Column, Enum, MetaData, Numeric, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('enum', Enum('A', 'B', name='blah')), - Column('bool', Boolean), - Column('number', Numeric(10, asdecimal=False)) - ) - """, - ) - - def test_boolean_detection(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("bool1", INTEGER), - Column("bool2", SMALLINT), - Column("bool3", mysql.TINYINT), - CheckConstraint("simple_items.bool1 IN (0, 1)"), - CheckConstraint("simple_items.bool2 IN (0, 1)"), - CheckConstraint("simple_items.bool3 IN (0, 1)"), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Boolean, Column, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('bool1', Boolean), - Column('bool2', Boolean), - Column('bool3', Boolean) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) - def test_arrays(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column( - "dp_array", postgresql.ARRAY(postgresql.DOUBLE_PRECISION(precision=53)) - ), - Column("int_array", postgresql.ARRAY(INTEGER)), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import ARRAY, Column, Float, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('dp_array', ARRAY(Float(precision=53))), - Column('int_array', ARRAY(Integer())) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) - def test_jsonb(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("jsonb", postgresql.JSONB(astext_type=Text(50))), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, MetaData, Table, Text - from sqlalchemy.dialects.postgresql import JSONB - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('jsonb', JSONB(astext_type=Text(length=50))) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) - def test_jsonb_default(self, generator: CodeGenerator) -> None: - Table("simple_items", generator.metadata, Column("jsonb", postgresql.JSONB)) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, MetaData, Table - from sqlalchemy.dialects.postgresql import JSONB - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('jsonb', JSONB) - ) - """, - ) - - def test_enum_detection(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("enum", VARCHAR(255)), - CheckConstraint(r"simple_items.enum IN ('A', '\'B', 'C')"), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Enum, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('enum', Enum('A', "\\\\'B", 'C')) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) - def test_column_adaptation(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", postgresql.BIGINT), - Column("length", postgresql.DOUBLE_PRECISION), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import BigInteger, Column, Float, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', BigInteger), - Column('length', Float) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) - def test_mysql_column_types(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", mysql.INTEGER), - Column("name", mysql.VARCHAR(255)), - Column("set", mysql.SET("one", "two")), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, String, Table - from sqlalchemy.dialects.mysql import SET - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer), - Column('name', String(255)), - Column('set', SET('one', 'two')) - ) - """, - ) - - def test_constraints(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER), - Column("number", INTEGER), - CheckConstraint("number > 0"), - UniqueConstraint("id", "number"), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table, \ -UniqueConstraint - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer), - Column('number', Integer), - CheckConstraint('number > 0'), - UniqueConstraint('id', 'number') - ) - """, - ) - - def test_indexes(self, generator: CodeGenerator) -> None: - simple_items = Table( - "simple_items", - generator.metadata, - Column("id", INTEGER), - Column("number", INTEGER), - Column("text", VARCHAR), - Index("ix_empty"), - ) - simple_items.indexes.add(Index("ix_number", simple_items.c.number)) - simple_items.indexes.add( - Index( - "ix_text_number", - simple_items.c.text, - simple_items.c.number, - unique=True, - ) - ) - simple_items.indexes.add(Index("ix_text", simple_items.c.text, unique=True)) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Index, Integer, MetaData, String, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer), - Column('number', Integer, index=True), - Column('text', String, unique=True, index=True), - Index('ix_empty'), - Index('ix_text_number', 'text', 'number', unique=True) - ) - """, - ) - - def test_table_comment(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True), - comment="this is a 'comment'", - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple = Table( - 'simple', metadata, - Column('id', Integer, primary_key=True), - comment="this is a 'comment'" - ) - """, - ) - - def test_table_name_identifiers(self, generator: CodeGenerator) -> None: - Table( - "simple-items table", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple_items_table = Table( - 'simple-items table', metadata, - Column('id', Integer, primary_key=True) - ) - """, - ) - - @pytest.mark.parametrize("generator", [["noindexes"]], indirect=True) - def test_option_noindexes(self, generator: CodeGenerator) -> None: - simple_items = Table( - "simple_items", - generator.metadata, - Column("number", INTEGER), - CheckConstraint("number > 2"), - ) - simple_items.indexes.add(Index("idx_number", simple_items.c.number)) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('number', Integer), - CheckConstraint('number > 2') - ) - """, - ) - - @pytest.mark.parametrize("generator", [["noconstraints"]], indirect=True) - def test_option_noconstraints(self, generator: CodeGenerator) -> None: - simple_items = Table( - "simple_items", - generator.metadata, - Column("number", INTEGER), - CheckConstraint("number > 2"), - ) - simple_items.indexes.add(Index("ix_number", simple_items.c.number)) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('number', Integer, index=True) - ) - """, - ) - - @pytest.mark.parametrize("generator", [["nocomments"]], indirect=True) - def test_option_nocomments(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True, comment="pk column comment"), - comment="this is a 'comment'", - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple = Table( - 'simple', metadata, - Column('id', Integer, primary_key=True) - ) - """, - ) - - @pytest.mark.parametrize( - "persisted, extra_args", - [(None, ""), (False, ", persisted=False"), (True, ", persisted=True")], - ) - def test_computed_column( - self, generator: CodeGenerator, persisted: "bool | None", extra_args: str - ) -> None: - Table( - "computed", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("computed", INTEGER, Computed("1 + 2", persisted=persisted)), - ) - - validate_code( - generator.generate(), - f"""\ - from sqlalchemy import Column, Computed, Integer, MetaData, Table - - metadata = MetaData() - - - t_computed = Table( - 'computed', metadata, - Column('id', Integer, primary_key=True), - Column('computed', Integer, Computed('1 + 2'{extra_args})) - ) - """, - ) - - def test_schema(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("name", VARCHAR), - schema="testschema", - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, MetaData, String, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('name', String), - schema='testschema' - ) - """, - ) - - def test_foreign_key_options(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column( - "name", - VARCHAR, - ForeignKey( - "simple_items.name", - ondelete="CASCADE", - onupdate="CASCADE", - deferrable=True, - initially="DEFERRED", - ), - ), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, ForeignKey, MetaData, String, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('name', String, ForeignKey('simple_items.name', \ -ondelete='CASCADE', onupdate='CASCADE', deferrable=True, initially='DEFERRED')) - ) - """, - ) - - def test_pk_default(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column( - "id", - INTEGER, - primary_key=True, - server_default=text("uuid_generate_v4()"), - ), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table, text - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, primary_key=True, \ -server_default=text('uuid_generate_v4()')) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) - def test_mysql_timestamp(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("timestamp", mysql.TIMESTAMP), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, TIMESTAMP, Table - - metadata = MetaData() - - - t_simple = Table( - 'simple', metadata, - Column('id', Integer, primary_key=True), - Column('timestamp', TIMESTAMP) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) - def test_mysql_integer_display_width(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("number", mysql.INTEGER(11)), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - from sqlalchemy.dialects.mysql import INTEGER - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, primary_key=True), - Column('number', INTEGER(11)) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) - def test_mysql_tinytext(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("my_tinytext", mysql.TINYTEXT), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - from sqlalchemy.dialects.mysql import TINYTEXT - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, primary_key=True), - Column('my_tinytext', TINYTEXT) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) - def test_mysql_mediumtext(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("my_mediumtext", mysql.MEDIUMTEXT), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - from sqlalchemy.dialects.mysql import MEDIUMTEXT - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, primary_key=True), - Column('my_mediumtext', MEDIUMTEXT) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) - def test_mysql_longtext(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("my_longtext", mysql.LONGTEXT), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - from sqlalchemy.dialects.mysql import LONGTEXT - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, primary_key=True), - Column('my_longtext', LONGTEXT) - ) - """, - ) - - def test_schema_boolean(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("bool1", INTEGER), - CheckConstraint("testschema.simple_items.bool1 IN (0, 1)"), - schema="testschema", - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Boolean, Column, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('bool1', Boolean), - schema='testschema' - ) - """, - ) - - def test_server_default_multiline(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column( - "id", - INTEGER, - primary_key=True, - server_default=text( - dedent( - """\ - /*Comment*/ - /*Next line*/ - something()""" - ) - ), - ), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table, text - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, primary_key=True, server_default=\ -text('/*Comment*/\\n/*Next line*/\\nsomething()')) - ) - """, - ) - - def test_server_default_colon(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("problem", VARCHAR, server_default=text("':001'")), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, MetaData, String, Table, text - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('problem', String, server_default=text("':001'")) - ) - """, - ) - - def test_null_type(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("problem", NullType), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, MetaData, Table - from sqlalchemy.sql.sqltypes import NullType - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('problem', NullType) - ) - """, - ) - - def test_identity_column(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column( - "id", - INTEGER, - primary_key=True, - server_default=Identity(start=1, increment=2), - ), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Identity, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, Identity(start=1, increment=2), primary_key=True) - ) - """, - ) - - def test_multiline_column_comment(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, comment="This\nis a multi-line\ncomment"), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, comment='This\\nis a multi-line\\ncomment') - ) - """, - ) - - def test_multiline_table_comment(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER), - comment="This\nis a multi-line\ncomment", - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer), - comment='This\\nis a multi-line\\ncomment' - ) - """, - ) - - @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) - def test_postgresql_sequence_standard_name(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column( - "id", - INTEGER, - primary_key=True, - server_default=text("nextval('simple_items_id_seq'::regclass)"), - ), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, primary_key=True) - ) - """, - ) - - @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) - def test_postgresql_sequence_nonstandard_name( - self, generator: CodeGenerator - ) -> None: - Table( - "simple_items", - generator.metadata, - Column( - "id", - INTEGER, - primary_key=True, - server_default=text("nextval('test_seq'::regclass)"), - ), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Sequence, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, Sequence('test_seq'), primary_key=True) - ) - """, - ) - - @pytest.mark.parametrize( - "schemaname, seqname", - [ - pytest.param("myschema", "test_seq"), - pytest.param("myschema", '"test_seq"'), - pytest.param('"my.schema"', "test_seq"), - pytest.param('"my.schema"', '"test_seq"'), - ], - ) - @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) - def test_postgresql_sequence_with_schema( - self, generator: CodeGenerator, schemaname: str, seqname: str - ) -> None: - expected_schema = schemaname.strip('"') - Table( - "simple_items", - generator.metadata, - Column( - "id", - INTEGER, - primary_key=True, - server_default=text(f"nextval('{schemaname}.{seqname}'::regclass)"), - ), - schema=expected_schema, - ) - - validate_code( - generator.generate(), - f"""\ - from sqlalchemy import Column, Integer, MetaData, Sequence, Table - - metadata = MetaData() - - - t_simple_items = Table( - 'simple_items', metadata, - Column('id', Integer, Sequence('test_seq', \ -schema='{expected_schema}'), primary_key=True), - schema='{expected_schema}' - ) - """, - ) - - -class TestDeclarativeGenerator: - @pytest.fixture - def generator( - self, request: FixtureRequest, metadata: MetaData, engine: Engine - ) -> CodeGenerator: - options = getattr(request, "param", []) - return DeclarativeGenerator(metadata, engine, options) - - def test_indexes(self, generator: CodeGenerator) -> None: - simple_items = Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("number", INTEGER), - Column("text", VARCHAR), - ) - simple_items.indexes.add(Index("idx_number", simple_items.c.number)) - simple_items.indexes.add( - Index("idx_text_number", simple_items.c.text, simple_items.c.number) - ) - simple_items.indexes.add(Index("idx_text", simple_items.c.text, unique=True)) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Index, Integer, String - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - - - class SimpleItems(Base): - __tablename__ = 'simple_items' - __table_args__ = ( - Index('idx_number', 'number'), - Index('idx_text', 'text', unique=True), - Index('idx_text_number', 'text', 'number') - ) - - id = Column(Integer, primary_key=True) - number = Column(Integer) - text = Column(String) - """, - ) - - def test_constraints(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("number", INTEGER), - CheckConstraint("number > 0"), - UniqueConstraint("id", "number"), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import CheckConstraint, Column, Integer, UniqueConstraint - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - - - class SimpleItems(Base): - __tablename__ = 'simple_items' - __table_args__ = ( - CheckConstraint('number > 0'), - UniqueConstraint('id', 'number') - ) - - id = Column(Integer, primary_key=True) - number = Column(Integer) - """, - ) - - def test_onetomany(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("container_id", INTEGER), - ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, ForeignKey, Integer - from sqlalchemy.orm import declarative_base, relationship - - Base = declarative_base() - - - class SimpleContainers(Base): - __tablename__ = 'simple_containers' - - id = Column(Integer, primary_key=True) - - simple_items = relationship('SimpleItems', back_populates='container') - - - class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - container_id = Column(ForeignKey('simple_containers.id')) - - container = relationship('SimpleContainers', \ -back_populates='simple_items') - """, - ) - - def test_onetomany_selfref(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("parent_item_id", INTEGER), - ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, ForeignKey, Integer - from sqlalchemy.orm import declarative_base, relationship - - Base = declarative_base() - - - class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - parent_item_id = Column(ForeignKey('simple_items.id')) - - parent_item = relationship('SimpleItems', remote_side=[id], \ -back_populates='parent_item_reverse') - parent_item_reverse = relationship('SimpleItems', \ -remote_side=[parent_item_id], back_populates='parent_item') - """, - ) - - def test_onetomany_selfref_multi(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("parent_item_id", INTEGER), - Column("top_item_id", INTEGER), - ForeignKeyConstraint(["parent_item_id"], ["simple_items.id"]), - ForeignKeyConstraint(["top_item_id"], ["simple_items.id"]), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, ForeignKey, Integer - from sqlalchemy.orm import declarative_base, relationship - - Base = declarative_base() - - - class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - parent_item_id = Column(ForeignKey('simple_items.id')) - top_item_id = Column(ForeignKey('simple_items.id')) - - parent_item = relationship('SimpleItems', remote_side=[id], \ -foreign_keys=[parent_item_id], back_populates='parent_item_reverse') - parent_item_reverse = relationship('SimpleItems', \ -remote_side=[parent_item_id], foreign_keys=[parent_item_id], \ -back_populates='parent_item') - top_item = relationship('SimpleItems', remote_side=[id], \ -foreign_keys=[top_item_id], back_populates='top_item_reverse') - top_item_reverse = relationship('SimpleItems', \ -remote_side=[top_item_id], foreign_keys=[top_item_id], back_populates='top_item') - """, - ) - - def test_onetomany_composite(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("container_id1", INTEGER), - Column("container_id2", INTEGER), - ForeignKeyConstraint( - ["container_id1", "container_id2"], - ["simple_containers.id1", "simple_containers.id2"], - ondelete="CASCADE", - onupdate="CASCADE", - ), - ) - Table( - "simple_containers", - generator.metadata, - Column("id1", INTEGER, primary_key=True), - Column("id2", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKeyConstraint, Integer -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() - - -class SimpleContainers(Base): - __tablename__ = 'simple_containers' - - id1 = Column(Integer, primary_key=True, nullable=False) - id2 = Column(Integer, primary_key=True, nullable=False) - - simple_items = relationship('SimpleItems', back_populates='simple_containers') - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - __table_args__ = ( - ForeignKeyConstraint(['container_id1', 'container_id2'], \ -['simple_containers.id1', 'simple_containers.id2'], ondelete='CASCADE', \ -onupdate='CASCADE'), - ) - - id = Column(Integer, primary_key=True) - container_id1 = Column(Integer) - container_id2 = Column(Integer) - - simple_containers = relationship('SimpleContainers', back_populates='simple_items') - """, - ) - - def test_onetomany_multiref(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("parent_container_id", INTEGER), - Column("top_container_id", INTEGER), - ForeignKeyConstraint(["parent_container_id"], ["simple_containers.id"]), - ForeignKeyConstraint(["top_container_id"], ["simple_containers.id"]), - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() - - -class SimpleContainers(Base): - __tablename__ = 'simple_containers' - - id = Column(Integer, primary_key=True) - - simple_items = relationship('SimpleItems', \ -foreign_keys='[SimpleItems.parent_container_id]', back_populates='parent_container') - simple_items_ = relationship('SimpleItems', \ -foreign_keys='[SimpleItems.top_container_id]', back_populates='top_container') - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - parent_container_id = Column(ForeignKey('simple_containers.id')) - top_container_id = Column(ForeignKey('simple_containers.id')) - - parent_container = relationship('SimpleContainers', \ -foreign_keys=[parent_container_id], back_populates='simple_items') - top_container = relationship('SimpleContainers', \ -foreign_keys=[top_container_id], back_populates='simple_items_') - """, - ) - - def test_onetoone(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("other_item_id", INTEGER), - ForeignKeyConstraint(["other_item_id"], ["other_items.id"]), - UniqueConstraint("other_item_id"), - ) - Table( - "other_items", generator.metadata, Column("id", INTEGER, primary_key=True) - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() - - -class OtherItems(Base): - __tablename__ = 'other_items' - - id = Column(Integer, primary_key=True) - - simple_items = relationship('SimpleItems', uselist=False, \ -back_populates='other_item') - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - other_item_id = Column(ForeignKey('other_items.id'), unique=True) - - other_item = relationship('OtherItems', back_populates='simple_items') - """, - ) - - def test_onetomany_noinflect(self, generator: CodeGenerator) -> None: - Table( - "oglkrogk", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("fehwiuhfiwID", INTEGER), - ForeignKeyConstraint(["fehwiuhfiwID"], ["fehwiuhfiw.id"]), - ) - Table("fehwiuhfiw", generator.metadata, Column("id", INTEGER, primary_key=True)) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() - - -class Fehwiuhfiw(Base): - __tablename__ = 'fehwiuhfiw' - - id = Column(Integer, primary_key=True) - - oglkrogk = relationship('Oglkrogk', back_populates='fehwiuhfiw') - - -class Oglkrogk(Base): - __tablename__ = 'oglkrogk' - - id = Column(Integer, primary_key=True) - fehwiuhfiwID = Column(ForeignKey('fehwiuhfiw.id')) - - fehwiuhfiw = relationship('Fehwiuhfiw', back_populates='oglkrogk') - """, - ) - - def test_onetomany_conflicting_column(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("container_id", INTEGER), - ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("relationship", Text), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer, Text -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() - - -class SimpleContainers(Base): - __tablename__ = 'simple_containers' - - id = Column(Integer, primary_key=True) - relationship_ = Column('relationship', Text) - - simple_items = relationship('SimpleItems', back_populates='container') - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - container_id = Column(ForeignKey('simple_containers.id')) - - container = relationship('SimpleContainers', back_populates='simple_items') - """, - ) - - def test_onetomany_conflicting_relationship(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("relationship_id", INTEGER), - ForeignKeyConstraint(["relationship_id"], ["relationship.id"]), - ) - Table( - "relationship", generator.metadata, Column("id", INTEGER, primary_key=True) - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() - - -class Relationship(Base): - __tablename__ = 'relationship' - - id = Column(Integer, primary_key=True) - - simple_items = relationship('SimpleItems', back_populates='relationship_') - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - relationship_id = Column(ForeignKey('relationship.id')) - - relationship_ = relationship('Relationship', back_populates='simple_items') - """, - ) - - @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True) - def test_manytoone_nobidi(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("container_id", INTEGER), - ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() - - -class SimpleContainers(Base): - __tablename__ = 'simple_containers' - - id = Column(Integer, primary_key=True) - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - container_id = Column(ForeignKey('simple_containers.id')) - - container = relationship('SimpleContainers') - """, - ) - - def test_manytomany(self, generator: CodeGenerator) -> None: - Table( - "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - Table( - "container_items", - generator.metadata, - Column("item_id", INTEGER), - Column("container_id", INTEGER), - ForeignKeyConstraint(["item_id"], ["simple_items.id"]), - ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer, Table -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() -metadata = Base.metadata - - -class SimpleContainers(Base): - __tablename__ = 'simple_containers' - - id = Column(Integer, primary_key=True) - - item = relationship('SimpleItems', secondary='container_items', \ -back_populates='container') - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - - container = relationship('SimpleContainers', secondary='container_items', \ -back_populates='item') - - -t_container_items = Table( - 'container_items', metadata, - Column('item_id', ForeignKey('simple_items.id')), - Column('container_id', ForeignKey('simple_containers.id')) -) - """, - ) - - @pytest.mark.parametrize("generator", [["nobidi"]], indirect=True) - def test_manytomany_nobidi(self, generator: CodeGenerator) -> None: - Table( - "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - Table( - "container_items", - generator.metadata, - Column("item_id", INTEGER), - Column("container_id", INTEGER), - ForeignKeyConstraint(["item_id"], ["simple_items.id"]), - ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer, Table -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() -metadata = Base.metadata - - -class SimpleContainers(Base): - __tablename__ = 'simple_containers' - - id = Column(Integer, primary_key=True) - - item = relationship('SimpleItems', secondary='container_items') - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - - -t_container_items = Table( - 'container_items', metadata, - Column('item_id', ForeignKey('simple_items.id')), - Column('container_id', ForeignKey('simple_containers.id')) -) - """, - ) - - def test_manytomany_selfref(self, generator: CodeGenerator) -> None: - Table( - "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) - ) - Table( - "child_items", - generator.metadata, - Column("parent_id", INTEGER), - Column("child_id", INTEGER), - ForeignKeyConstraint(["parent_id"], ["simple_items.id"]), - ForeignKeyConstraint(["child_id"], ["simple_items.id"]), - schema="otherschema", - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer, Table -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() -metadata = Base.metadata - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - - parent = relationship('SimpleItems', secondary='otherschema.child_items', \ -primaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, \ -secondaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, \ -back_populates='child') - child = relationship('SimpleItems', secondary='otherschema.child_items', \ -primaryjoin=lambda: SimpleItems.id == t_child_items.c.parent_id, \ -secondaryjoin=lambda: SimpleItems.id == t_child_items.c.child_id, \ -back_populates='parent') - - -t_child_items = Table( - 'child_items', metadata, - Column('parent_id', ForeignKey('simple_items.id')), - Column('child_id', ForeignKey('simple_items.id')), - schema='otherschema' -) - """, - ) - - def test_manytomany_composite(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id1", INTEGER, primary_key=True), - Column("id2", INTEGER, primary_key=True), - ) - Table( - "simple_containers", - generator.metadata, - Column("id1", INTEGER, primary_key=True), - Column("id2", INTEGER, primary_key=True), - ) - Table( - "container_items", - generator.metadata, - Column("item_id1", INTEGER), - Column("item_id2", INTEGER), - Column("container_id1", INTEGER), - Column("container_id2", INTEGER), - ForeignKeyConstraint( - ["item_id1", "item_id2"], ["simple_items.id1", "simple_items.id2"] - ), - ForeignKeyConstraint( - ["container_id1", "container_id2"], - ["simple_containers.id1", "simple_containers.id2"], - ), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKeyConstraint, Integer, Table -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() -metadata = Base.metadata - - -class SimpleContainers(Base): - __tablename__ = 'simple_containers' - - id1 = Column(Integer, primary_key=True, nullable=False) - id2 = Column(Integer, primary_key=True, nullable=False) - - simple_items = relationship('SimpleItems', secondary='container_items', \ -back_populates='simple_containers') - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id1 = Column(Integer, primary_key=True, nullable=False) - id2 = Column(Integer, primary_key=True, nullable=False) - - simple_containers = relationship('SimpleContainers', secondary='container_items', \ -back_populates='simple_items') - - -t_container_items = Table( - 'container_items', metadata, - Column('item_id1', Integer), - Column('item_id2', Integer), - Column('container_id1', Integer), - Column('container_id2', Integer), - ForeignKeyConstraint(['container_id1', 'container_id2'], \ -['simple_containers.id1', 'simple_containers.id2']), - ForeignKeyConstraint(['item_id1', 'item_id2'], ['simple_items.id1', \ -'simple_items.id2']) -) - """, - ) - - def test_joined_inheritance(self, generator: CodeGenerator) -> None: - Table( - "simple_sub_items", - generator.metadata, - Column("simple_items_id", INTEGER, primary_key=True), - Column("data3", INTEGER), - ForeignKeyConstraint(["simple_items_id"], ["simple_items.super_item_id"]), - ) - Table( - "simple_super_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("data1", INTEGER), - ) - Table( - "simple_items", - generator.metadata, - Column("super_item_id", INTEGER, primary_key=True), - Column("data2", INTEGER), - ForeignKeyConstraint(["super_item_id"], ["simple_super_items.id"]), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -class SimpleSuperItems(Base): - __tablename__ = 'simple_super_items' - - id = Column(Integer, primary_key=True) - data1 = Column(Integer) - - -class SimpleItems(SimpleSuperItems): - __tablename__ = 'simple_items' - - super_item_id = Column(ForeignKey('simple_super_items.id'), primary_key=True) - data2 = Column(Integer) - - -class SimpleSubItems(SimpleItems): - __tablename__ = 'simple_sub_items' - - simple_items_id = Column(ForeignKey('simple_items.super_item_id'), primary_key=True) - data3 = Column(Integer) - """, - ) - - def test_joined_inheritance_same_table_name(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - Table( - "simple", - generator.metadata, - Column("id", INTEGER, ForeignKey("simple.id"), primary_key=True), - schema="altschema", - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, ForeignKey, Integer - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - - - class Simple(Base): - __tablename__ = 'simple' - - id = Column(Integer, primary_key=True) - - - class Simple_(Simple): - __tablename__ = 'simple' - __table_args__ = {'schema': 'altschema'} - - id = Column(ForeignKey('simple.id'), primary_key=True) - """, - ) - - @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True) - def test_use_inflect(self, generator: CodeGenerator) -> None: - Table( - "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) - ) - - Table("singular", generator.metadata, Column("id", INTEGER, primary_key=True)) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, Integer -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -class SimpleItem(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - - -class Singular(Base): - __tablename__ = 'singular' - - id = Column(Integer, primary_key=True) - """, - ) - - @pytest.mark.parametrize("generator", [["use_inflect"]], indirect=True) - @pytest.mark.parametrize( - argnames=("table_name", "class_name", "relationship_name"), - argvalues=[ - ("manufacturers", "manufacturer", "manufacturer"), - ("statuses", "status", "status"), - ("studies", "study", "study"), - ("moose", "moose", "moose"), - ], - ids=[ - "test_inflect_manufacturer", - "test_inflect_status", - "test_inflect_study", - "test_inflect_moose", - ], - ) - def test_use_inflect_plural( - self, - generator: CodeGenerator, - table_name: str, - class_name: str, - relationship_name: str, - ) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column(f"{relationship_name}_id", INTEGER), - ForeignKeyConstraint([f"{relationship_name}_id"], [f"{table_name}.id"]), - UniqueConstraint(f"{relationship_name}_id"), - ) - Table(table_name, generator.metadata, Column("id", INTEGER, primary_key=True)) - - validate_code( - generator.generate(), - f"""\ -from sqlalchemy import Column, ForeignKey, Integer -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() - - -class {class_name.capitalize()}(Base): - __tablename__ = '{table_name}' - - id = Column(Integer, primary_key=True) - - simple_item = relationship('SimpleItem', uselist=False, \ -back_populates='{relationship_name}') - - -class SimpleItem(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - {relationship_name}_id = Column(ForeignKey('{table_name}.id'), unique=True) - - {relationship_name} = relationship('{class_name.capitalize()}', \ -back_populates='simple_item') - """, - ) - - def test_table_kwargs(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - schema="testschema", - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, Integer -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - __table_args__ = {'schema': 'testschema'} - - id = Column(Integer, primary_key=True) - """, - ) - - def test_table_args_kwargs(self, generator: CodeGenerator) -> None: - simple_items = Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("name", VARCHAR), - schema="testschema", - ) - simple_items.indexes.add( - Index("testidx", simple_items.c.id, simple_items.c.name) - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, Index, Integer, String -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - __table_args__ = ( - Index('testidx', 'id', 'name'), - {'schema': 'testschema'} - ) - - id = Column(Integer, primary_key=True) - name = Column(String) - """, - ) - - def test_foreign_key_schema(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("other_item_id", INTEGER), - ForeignKeyConstraint(["other_item_id"], ["otherschema.other_items.id"]), - ) - Table( - "other_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - schema="otherschema", - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, ForeignKey, Integer -from sqlalchemy.orm import declarative_base, relationship - -Base = declarative_base() - - -class OtherItems(Base): - __tablename__ = 'other_items' - __table_args__ = {'schema': 'otherschema'} - - id = Column(Integer, primary_key=True) - - simple_items = relationship('SimpleItems', back_populates='other_item') - - -class SimpleItems(Base): - __tablename__ = 'simple_items' - - id = Column(Integer, primary_key=True) - other_item_id = Column(ForeignKey('otherschema.other_items.id')) - - other_item = relationship('OtherItems', back_populates='simple_items') - """, - ) - - def test_invalid_attribute_names(self, generator: CodeGenerator) -> None: - Table( - "simple-items", - generator.metadata, - Column("id-test", INTEGER, primary_key=True), - Column("4test", INTEGER), - Column("_4test", INTEGER), - Column("def", INTEGER), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, Integer -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -class SimpleItems(Base): - __tablename__ = 'simple-items' - - id_test = Column('id-test', Integer, primary_key=True) - _4test = Column('4test', Integer) - _4test_ = Column('_4test', Integer) - def_ = Column('def', Integer) - """, - ) - - def test_pascal(self, generator: CodeGenerator) -> None: - Table( - "CustomerAPIPreference", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, Integer -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -class CustomerAPIPreference(Base): - __tablename__ = 'CustomerAPIPreference' - - id = Column(Integer, primary_key=True) - """, - ) - - def test_underscore(self, generator: CodeGenerator) -> None: - Table( - "customer_api_preference", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, Integer -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -class CustomerApiPreference(Base): - __tablename__ = 'customer_api_preference' - - id = Column(Integer, primary_key=True) - """, - ) - - def test_pascal_underscore(self, generator: CodeGenerator) -> None: - Table( - "customer_API_Preference", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, Integer -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -class CustomerAPIPreference(Base): - __tablename__ = 'customer_API_Preference' - - id = Column(Integer, primary_key=True) - """, - ) - - def test_pascal_multiple_underscore(self, generator: CodeGenerator) -> None: - Table( - "customer_API__Preference", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ -from sqlalchemy import Column, Integer -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - - -class CustomerAPIPreference(Base): - __tablename__ = 'customer_API__Preference' - - id = Column(Integer, primary_key=True) - """, - ) - - @pytest.mark.parametrize( - "generator, nocomments", - [([], False), (["nocomments"], True)], - indirect=["generator"], - ) - def test_column_comment(self, generator: CodeGenerator, nocomments: bool) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True, comment="this is a 'comment'"), - ) - - comment_part = "" if nocomments else ", comment=\"this is a 'comment'\"" - validate_code( - generator.generate(), - f"""\ - from sqlalchemy import Column, Integer - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - - - class Simple(Base): - __tablename__ = 'simple' - - id = Column(Integer, primary_key=True{comment_part}) - """, - ) - - def test_table_comment(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True), - comment="this is a 'comment'", - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - - - class Simple(Base): - __tablename__ = 'simple' - __table_args__ = {'comment': "this is a 'comment'"} - - id = Column(Integer, primary_key=True) - """, - ) - - def test_metadata_column(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("metadata", VARCHAR), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, String - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - - - class Simple(Base): - __tablename__ = 'simple' - - id = Column(Integer, primary_key=True) - metadata_ = Column('metadata', String) - """, - ) - - def test_invalid_variable_name_from_column(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column(" id ", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - - - class Simple(Base): - __tablename__ = 'simple' - - id = Column(' id ', Integer, primary_key=True) - """, - ) - - def test_only_tables(self, generator: CodeGenerator) -> None: - Table("simple", generator.metadata, Column("id", INTEGER)) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, MetaData, Table - - metadata = MetaData() - - - t_simple = Table( - 'simple', metadata, - Column('id', Integer) - ) - """, - ) - - def test_named_constraints(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER), - Column("text", VARCHAR), - CheckConstraint("id > 0", name="checktest"), - PrimaryKeyConstraint("id", name="primarytest"), - UniqueConstraint("text", name="uniquetest"), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import CheckConstraint, Column, Integer, \ -PrimaryKeyConstraint, String, UniqueConstraint - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - - - class Simple(Base): - __tablename__ = 'simple' - __table_args__ = ( - CheckConstraint('id > 0', name='checktest'), - PrimaryKeyConstraint('id', name='primarytest'), - UniqueConstraint('text', name='uniquetest') - ) - - id = Column(Integer) - text = Column(String) - """, - ) - - def test_named_foreign_key_constraints(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("container_id", INTEGER), - ForeignKeyConstraint( - ["container_id"], ["simple_containers.id"], name="foreignkeytest" - ), - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, ForeignKeyConstraint, Integer - from sqlalchemy.orm import declarative_base, relationship - - Base = declarative_base() - - - class SimpleContainers(Base): - __tablename__ = 'simple_containers' - - id = Column(Integer, primary_key=True) - - simple_items = relationship('SimpleItems', back_populates='container') - - - class SimpleItems(Base): - __tablename__ = 'simple_items' - __table_args__ = ( - ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \ -name='foreignkeytest'), - ) - - id = Column(Integer, primary_key=True) - container_id = Column(Integer) - - container = relationship('SimpleContainers', \ -back_populates='simple_items') - """, - ) - - # @pytest.mark.xfail(strict=True) - def test_colname_import_conflict(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("text", VARCHAR), - Column("textwithdefault", VARCHAR, server_default=text("'test'")), - ) - - validate_code( - generator.generate(), - """\ - from sqlalchemy import Column, Integer, String, text - from sqlalchemy.orm import declarative_base - - Base = declarative_base() - - - class Simple(Base): - __tablename__ = 'simple' - - id = Column(Integer, primary_key=True) - text_ = Column('text', String) - textwithdefault = Column(String, server_default=text("'test'")) - """, - ) - - -class TestDataclassGenerator: - @pytest.fixture - def generator( - self, request: FixtureRequest, metadata: MetaData, engine: Engine - ) -> CodeGenerator: - options = getattr(request, "param", []) - return DataclassGenerator(metadata, engine, options) - - def test_basic_class(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("name", VARCHAR(20)), - ) - - validate_code( - generator.generate(), - """\ - from __future__ import annotations - - from dataclasses import dataclass, field - from typing import Optional - - from sqlalchemy import Column, Integer, String - from sqlalchemy.orm import registry - - mapper_registry = registry() - - - @mapper_registry.mapped - @dataclass - class Simple: - __tablename__ = 'simple' - __sa_dataclass_metadata_key__ = 'sa' - - id: int = field(init=False, metadata={'sa': Column(Integer, \ -primary_key=True)}) - name: Optional[str] = field(default=None, metadata={'sa': \ -Column(String(20))}) - """, - ) - - def test_mandatory_field_last(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("name", VARCHAR(20), server_default=text("foo")), - Column("age", INTEGER, nullable=False), - ) - - validate_code( - generator.generate(), - """\ - from __future__ import annotations - - from dataclasses import dataclass, field - from typing import Optional - - from sqlalchemy import Column, Integer, String, text - from sqlalchemy.orm import registry - - mapper_registry = registry() - - - @mapper_registry.mapped - @dataclass - class Simple: - __tablename__ = 'simple' - __sa_dataclass_metadata_key__ = 'sa' - - id: int = field(init=False, metadata={'sa': Column(Integer, \ -primary_key=True)}) - age: int = field(metadata={'sa': Column(Integer, nullable=False)}) - name: Optional[str] = field(default=None, metadata={'sa': \ -Column(String(20), server_default=text('foo'))}) - """, - ) - - def test_onetomany_optional(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("container_id", INTEGER), - ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ - from __future__ import annotations - - from dataclasses import dataclass, field - from typing import List, Optional - - from sqlalchemy import Column, ForeignKey, Integer - from sqlalchemy.orm import registry, relationship - - mapper_registry = registry() - - - @mapper_registry.mapped - @dataclass - class SimpleContainers: - __tablename__ = 'simple_containers' - __sa_dataclass_metadata_key__ = 'sa' - - id: int = field(init=False, metadata={'sa': Column(Integer, \ -primary_key=True)}) - - simple_items: List[SimpleItems] = field(default_factory=list, \ -metadata={'sa': relationship('SimpleItems', back_populates='container')}) - - - @mapper_registry.mapped - @dataclass - class SimpleItems: - __tablename__ = 'simple_items' - __sa_dataclass_metadata_key__ = 'sa' - - id: int = field(init=False, metadata={'sa': Column(Integer, \ -primary_key=True)}) - container_id: Optional[int] = field(default=None, \ -metadata={'sa': Column(ForeignKey('simple_containers.id'))}) - - container: Optional[SimpleContainers] = field(default=None, \ -metadata={'sa': relationship('SimpleContainers', back_populates='simple_items')}) - """, - ) - - def test_manytomany(self, generator: CodeGenerator) -> None: - Table( - "simple_items", generator.metadata, Column("id", INTEGER, primary_key=True) - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - Table( - "container_items", - generator.metadata, - Column("item_id", INTEGER), - Column("container_id", INTEGER), - ForeignKeyConstraint(["item_id"], ["simple_items.id"]), - ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), - ) - - validate_code( - generator.generate(), - """\ - from __future__ import annotations - - from dataclasses import dataclass, field - from typing import List - - from sqlalchemy import Column, ForeignKey, Integer, Table - from sqlalchemy.orm import registry, relationship - - mapper_registry = registry() - metadata = mapper_registry.metadata - - - @mapper_registry.mapped - @dataclass - class SimpleContainers: - __tablename__ = 'simple_containers' - __sa_dataclass_metadata_key__ = 'sa' - - id: int = field(init=False, metadata={'sa': Column(Integer, \ -primary_key=True)}) - - item: List[SimpleItems] = field(default_factory=list, metadata=\ -{'sa': relationship('SimpleItems', secondary='container_items', \ -back_populates='container')}) - - - @mapper_registry.mapped - @dataclass - class SimpleItems: - __tablename__ = 'simple_items' - __sa_dataclass_metadata_key__ = 'sa' - - id: int = field(init=False, metadata={'sa': Column(Integer, \ -primary_key=True)}) - - container: List[SimpleContainers] = \ -field(default_factory=list, metadata={'sa': relationship('SimpleContainers', \ -secondary='container_items', back_populates='item')}) - - - t_container_items = Table( - 'container_items', metadata, - Column('item_id', ForeignKey('simple_items.id')), - Column('container_id', ForeignKey('simple_containers.id')) - ) - """, - ) - - def test_named_foreign_key_constraints(self, generator: CodeGenerator) -> None: - Table( - "simple_items", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("container_id", INTEGER), - ForeignKeyConstraint( - ["container_id"], ["simple_containers.id"], name="foreignkeytest" - ), - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ - from __future__ import annotations - - from dataclasses import dataclass, field - from typing import List, Optional - - from sqlalchemy import Column, ForeignKeyConstraint, Integer - from sqlalchemy.orm import registry, relationship - - mapper_registry = registry() - - - @mapper_registry.mapped - @dataclass - class SimpleContainers: - __tablename__ = 'simple_containers' - __sa_dataclass_metadata_key__ = 'sa' - - id: int = field(init=False, metadata={'sa': Column(Integer, \ -primary_key=True)}) - - simple_items: List[SimpleItems] = field(default_factory=list, \ -metadata={'sa': relationship('SimpleItems', back_populates='container')}) - - - @mapper_registry.mapped - @dataclass - class SimpleItems: - __tablename__ = 'simple_items' - __table_args__ = ( - ForeignKeyConstraint(['container_id'], ['simple_containers.id'], \ -name='foreignkeytest'), - ) - __sa_dataclass_metadata_key__ = 'sa' - - id: int = field(init=False, metadata={'sa': Column(Integer, \ -primary_key=True)}) - container_id: Optional[int] = field(default=None, metadata={'sa': \ -Column(Integer)}) - - container: Optional[SimpleContainers] = field(default=None, \ -metadata={'sa': relationship('SimpleContainers', back_populates='simple_items')}) - """, - ) - - def test_uuid_type_annotation(self, generator: CodeGenerator) -> None: - Table( - "simple", - generator.metadata, - Column("id", UUID, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ - from __future__ import annotations - - from dataclasses import dataclass, field - - from sqlalchemy import Column - from sqlalchemy.dialects.postgresql import UUID - from sqlalchemy.orm import registry - - mapper_registry = registry() - - - @mapper_registry.mapped - @dataclass - class Simple: - __tablename__ = 'simple' - __sa_dataclass_metadata_key__ = 'sa' - - id: str = field(init=False, metadata={'sa': \ -Column(UUID, primary_key=True)}) - """, - ) - - -class TestSQLModelGenerator: - @pytest.fixture - def generator( - self, request: FixtureRequest, metadata: MetaData, engine: Engine - ) -> CodeGenerator: - options = getattr(request, "param", []) - return SQLModelGenerator(metadata, engine, options) - - def test_indexes(self, generator: CodeGenerator) -> None: - simple_items = Table( - "item", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("number", INTEGER), - Column("text", VARCHAR), - ) - simple_items.indexes.add(Index("idx_number", simple_items.c.number)) - simple_items.indexes.add( - Index("idx_text_number", simple_items.c.text, simple_items.c.number) - ) - simple_items.indexes.add(Index("idx_text", simple_items.c.text, unique=True)) - - validate_code( - generator.generate(), - """\ - from typing import Optional - - from sqlalchemy import Column, Index, Integer, String - from sqlmodel import Field, SQLModel - - class Item(SQLModel, table=True): - __table_args__ = ( - Index('idx_number', 'number'), - Index('idx_text', 'text', unique=True), - Index('idx_text_number', 'text', 'number') - ) - - id: Optional[int] = Field(default=None, sa_column=Column(\ -'id', Integer, primary_key=True)) - number: Optional[int] = Field(default=None, sa_column=Column(\ -'number', Integer)) - text: Optional[str] = Field(default=None, sa_column=Column(\ -'text', String)) - """, - ) - - def test_constraints(self, generator: CodeGenerator) -> None: - Table( - "simple_constraints", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("number", INTEGER), - CheckConstraint("number > 0"), - UniqueConstraint("id", "number"), - ) - - validate_code( - generator.generate(), - """\ - from typing import Optional - - from sqlalchemy import CheckConstraint, Column, Integer, UniqueConstraint - from sqlmodel import Field, SQLModel - - class SimpleConstraints(SQLModel, table=True): - __tablename__ = 'simple_constraints' - __table_args__ = ( - CheckConstraint('number > 0'), - UniqueConstraint('id', 'number') - ) - - id: Optional[int] = Field(default=None, sa_column=Column(\ -'id', Integer, primary_key=True)) - number: Optional[int] = Field(default=None, sa_column=Column(\ -'number', Integer)) - """, - ) - - def test_onetomany(self, generator: CodeGenerator) -> None: - Table( - "simple_goods", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("container_id", INTEGER), - ForeignKeyConstraint(["container_id"], ["simple_containers.id"]), - ) - Table( - "simple_containers", - generator.metadata, - Column("id", INTEGER, primary_key=True), - ) - - validate_code( - generator.generate(), - """\ - from typing import List, Optional - - from sqlalchemy import Column, ForeignKey, Integer - from sqlmodel import Field, Relationship, SQLModel - - class SimpleContainers(SQLModel, table=True): - __tablename__ = 'simple_containers' - - id: Optional[int] = Field(default=None, sa_column=Column(\ -'id', Integer, primary_key=True)) - - simple_goods: List['SimpleGoods'] = Relationship(\ -back_populates='container') - - - class SimpleGoods(SQLModel, table=True): - __tablename__ = 'simple_goods' - - id: Optional[int] = Field(default=None, sa_column=Column(\ -'id', Integer, primary_key=True)) - container_id: Optional[int] = Field(default=None, sa_column=Column(\ -'container_id', ForeignKey('simple_containers.id'))) - - container: Optional['SimpleContainers'] = Relationship(\ -back_populates='simple_goods') - """, - ) - - def test_onetoone(self, generator: CodeGenerator) -> None: - Table( - "simple_onetoone", - generator.metadata, - Column("id", INTEGER, primary_key=True), - Column("other_item_id", INTEGER), - ForeignKeyConstraint(["other_item_id"], ["other_items.id"]), - UniqueConstraint("other_item_id"), - ) - Table( - "other_items", generator.metadata, Column("id", INTEGER, primary_key=True) - ) - - validate_code( - generator.generate(), - """\ - from typing import Optional - - from sqlalchemy import Column, ForeignKey, Integer - from sqlmodel import Field, Relationship, SQLModel - - class OtherItems(SQLModel, table=True): - __tablename__ = 'other_items' - - id: Optional[int] = Field(default=None, sa_column=Column(\ -'id', Integer, primary_key=True)) - - simple_onetoone: Optional['SimpleOnetoone'] = Relationship(\ -sa_relationship_kwargs={'uselist': False}, back_populates='other_item') - - - class SimpleOnetoone(SQLModel, table=True): - __tablename__ = 'simple_onetoone' - - id: Optional[int] = Field(default=None, sa_column=Column(\ -'id', Integer, primary_key=True)) - other_item_id: Optional[int] = Field(default=None, sa_column=Column(\ -'other_item_id', ForeignKey('other_items.id'), unique=True)) - - other_item: Optional['OtherItems'] = Relationship(\ -back_populates='simple_onetoone') - """, - )