From 164bbcd78345a2bc5c8c179d56dd5328444d9024 Mon Sep 17 00:00:00 2001 From: Thomas Berdy Date: Tue, 23 Apr 2024 22:14:12 +0200 Subject: [PATCH] refactor(generator): Enable use of generators without engine --- src/sqlacodegen/cli.py | 2 +- src/sqlacodegen/generators.py | 32 +++++++++++++-------------- src/sqlacodegen/utils.py | 7 +++--- tests/conftest.py | 17 ++++++++------- tests/test_generator_dataclass.py | 6 ++--- tests/test_generator_declarative.py | 7 +++--- tests/test_generator_sqlmodel.py | 6 ++--- tests/test_generator_tables.py | 34 ++++++++++++++--------------- 8 files changed, 55 insertions(+), 56 deletions(-) diff --git a/src/sqlacodegen/cli.py b/src/sqlacodegen/cli.py index 755f542b..58549537 100644 --- a/src/sqlacodegen/cli.py +++ b/src/sqlacodegen/cli.py @@ -86,7 +86,7 @@ def main() -> None: # Instantiate the generator generator_class = generators[args.generator].load() - generator = generator_class(metadata, engine, options) + generator = generator_class(metadata, engine.dialect, options) # Open the target file (if given) with ExitStack() as stack: diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 21eadb63..8a46b2a3 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -25,6 +25,7 @@ Computed, Constraint, DefaultClause, + Dialect, Enum, Float, ForeignKey, @@ -39,7 +40,6 @@ UniqueConstraint, ) from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import CompileError from sqlalchemy.sql.elements import TextClause @@ -95,10 +95,10 @@ class CodeGenerator(metaclass=ABCMeta): valid_options: ClassVar[set[str]] = set() def __init__( - self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str] + self, metadata: MetaData, dialect: Dialect, options: Sequence[str] ): self.metadata: MetaData = metadata - self.bind: Connection | Engine = bind + self.dialect: Dialect = dialect self.options: set[str] = set(options) # Validate options @@ -124,12 +124,12 @@ class TablesGenerator(CodeGenerator): def __init__( self, metadata: MetaData, - bind: Connection | Engine, + dialect: Dialect, options: Sequence[str], *, indentation: str = " ", ): - super().__init__(metadata, bind, options) + super().__init__(metadata, dialect, options) self.indentation: str = indentation self.imports: dict[str, set[str]] = defaultdict(set) self.module_imports: set[str] = set() @@ -562,7 +562,7 @@ def add_fk_options(*opts: Any) -> None: ] add_fk_options(local_columns, remote_columns) elif isinstance(constraint, CheckConstraint): - args.append(repr(get_compiled_expression(constraint.sqltext, self.bind))) + args.append(repr(get_compiled_expression(constraint.sqltext, self.dialect))) elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)): args.extend(repr(col.name) for col in constraint.columns) else: @@ -608,7 +608,7 @@ def fix_column_types(self, table: Table) -> None: # Detect check constraints for boolean and enum columns for constraint in table.constraints.copy(): if isinstance(constraint, CheckConstraint): - sqltext = get_compiled_expression(constraint.sqltext, self.bind) + sqltext = get_compiled_expression(constraint.sqltext, self.dialect) # Turn any integer-like column with a CheckConstraint like # "column IN (0, 1)" into a Boolean @@ -646,7 +646,7 @@ def fix_column_types(self, table: Table) -> None: pass # PostgreSQL specific fix: detect sequences from server_default - if column.server_default and self.bind.dialect.name == "postgresql": + if column.server_default and self.dialect.name == "postgresql": if isinstance(column.server_default, DefaultClause) and isinstance( column.server_default.arg, TextClause ): @@ -661,7 +661,7 @@ def fix_column_types(self, table: Table) -> None: column.server_default = None def get_adapted_type(self, coltype: Any) -> Any: - compiled_type = coltype.compile(self.bind.engine.dialect) + compiled_type = coltype.compile(self.dialect) for supercls in coltype.__class__.__mro__: if not supercls.__name__.startswith("_") and hasattr( supercls, "__visit_name__" @@ -687,7 +687,7 @@ def get_adapted_type(self, coltype: Any) -> Any: try: # If the adapted column type does not render the same as the # original, don't substitute it - if new_coltype.compile(self.bind.engine.dialect) != compiled_type: + if new_coltype.compile(self.dialect) != compiled_type: # Make an exception to the rule for Float and arrays of Float, # since at least on PostgreSQL, Float can accurately represent # both REAL and DOUBLE_PRECISION @@ -718,13 +718,13 @@ class DeclarativeGenerator(TablesGenerator): def __init__( self, metadata: MetaData, - bind: Connection | Engine, + dialect: Dialect, options: Sequence[str], *, indentation: str = " ", base_class_name: str = "Base", ): - super().__init__(metadata, bind, options, indentation=indentation) + super().__init__(metadata, dialect, options, indentation=indentation) self.base_class_name: str = base_class_name self.inflect_engine = inflect.engine() @@ -1305,7 +1305,7 @@ class DataclassGenerator(DeclarativeGenerator): def __init__( self, metadata: MetaData, - bind: Connection | Engine, + dialect: Dialect, options: Sequence[str], *, indentation: str = " ", @@ -1315,7 +1315,7 @@ def __init__( ): super().__init__( metadata, - bind, + dialect, options, indentation=indentation, base_class_name=base_class_name, @@ -1344,7 +1344,7 @@ class SQLModelGenerator(DeclarativeGenerator): def __init__( self, metadata: MetaData, - bind: Connection | Engine, + dialect: Dialect, options: Sequence[str], *, indentation: str = " ", @@ -1352,7 +1352,7 @@ def __init__( ): super().__init__( metadata, - bind, + dialect, options, indentation=indentation, base_class_name=base_class_name, diff --git a/src/sqlacodegen/utils.py b/src/sqlacodegen/utils.py index 397f89e0..0ac9efc0 100644 --- a/src/sqlacodegen/utils.py +++ b/src/sqlacodegen/utils.py @@ -4,8 +4,7 @@ from collections.abc import Mapping from typing import Any -from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint -from sqlalchemy.engine import Connection, Engine +from sqlalchemy import Dialect, PrimaryKeyConstraint, UniqueConstraint from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import ( @@ -34,9 +33,9 @@ def get_constraint_sort_key(constraint: Constraint) -> str: return str(constraint) -def get_compiled_expression(statement: ClauseElement, bind: Engine | Connection) -> str: +def get_compiled_expression(statement: ClauseElement, dialect: Dialect) -> str: """Return the statement in a form where any placeholders have been filled in.""" - return str(statement.compile(bind, compile_kwargs={"literal_binds": True})) + return str(statement.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) def get_common_fk_constraints( diff --git a/tests/conftest.py b/tests/conftest.py index 022e786c..5b843e3f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,20 +2,21 @@ import pytest from pytest import FixtureRequest -from sqlalchemy.engine import Engine, create_engine +from sqlalchemy import Dialect +from sqlalchemy.dialects import mysql, postgresql, sqlite from sqlalchemy.orm import clear_mappers, configure_mappers from sqlalchemy.schema import MetaData @pytest.fixture -def engine(request: FixtureRequest) -> Engine: - dialect = getattr(request, "param", None) - if dialect == "postgresql": - return create_engine("postgresql:///testdb") - elif dialect == "mysql": - return create_engine("mysql+mysqlconnector://testdb") +def dialect(request: FixtureRequest) -> Dialect: + dialect_name = getattr(request, "param", None) + if dialect_name == "postgresql": + return postgresql.dialect() + elif dialect_name == "mysql": + return mysql.mysqlconnector.dialect() else: - return create_engine("sqlite:///:memory:") + return sqlite.dialect() @pytest.fixture diff --git a/tests/test_generator_dataclass.py b/tests/test_generator_dataclass.py index ae7eab25..e678d370 100644 --- a/tests/test_generator_dataclass.py +++ b/tests/test_generator_dataclass.py @@ -2,8 +2,8 @@ import pytest from _pytest.fixtures import FixtureRequest +from sqlalchemy import Dialect from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.engine import Engine from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table from sqlalchemy.sql.expression import text from sqlalchemy.types import INTEGER, VARCHAR @@ -15,10 +15,10 @@ @pytest.fixture def generator( - request: FixtureRequest, metadata: MetaData, engine: Engine + request: FixtureRequest, metadata: MetaData, dialect: Dialect ) -> CodeGenerator: options = getattr(request, "param", []) - return DataclassGenerator(metadata, engine, options) + return DataclassGenerator(metadata, dialect, options) def test_basic_class(generator: CodeGenerator) -> None: diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index d9bf7b53..a8b71f87 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -2,8 +2,7 @@ import pytest from _pytest.fixtures import FixtureRequest -from sqlalchemy import PrimaryKeyConstraint -from sqlalchemy.engine import Engine +from sqlalchemy import Dialect, PrimaryKeyConstraint from sqlalchemy.schema import ( CheckConstraint, Column, @@ -24,10 +23,10 @@ @pytest.fixture def generator( - request: FixtureRequest, metadata: MetaData, engine: Engine + request: FixtureRequest, metadata: MetaData, dialect: Dialect ) -> CodeGenerator: options = getattr(request, "param", []) - return DeclarativeGenerator(metadata, engine, options) + return DeclarativeGenerator(metadata, dialect, options) def test_indexes(generator: CodeGenerator) -> None: diff --git a/tests/test_generator_sqlmodel.py b/tests/test_generator_sqlmodel.py index baf92dd6..ef235c3d 100644 --- a/tests/test_generator_sqlmodel.py +++ b/tests/test_generator_sqlmodel.py @@ -2,7 +2,7 @@ import pytest from _pytest.fixtures import FixtureRequest -from sqlalchemy.engine import Engine +from sqlalchemy import Dialect from sqlalchemy.schema import ( CheckConstraint, Column, @@ -21,10 +21,10 @@ @pytest.fixture def generator( - request: FixtureRequest, metadata: MetaData, engine: Engine + request: FixtureRequest, metadata: MetaData, dialect: Dialect ) -> CodeGenerator: options = getattr(request, "param", []) - return SQLModelGenerator(metadata, engine, options) + return SQLModelGenerator(metadata, dialect, options) def test_indexes(generator: CodeGenerator) -> None: diff --git a/tests/test_generator_tables.py b/tests/test_generator_tables.py index bf6ff4ee..ae64aec7 100644 --- a/tests/test_generator_tables.py +++ b/tests/test_generator_tables.py @@ -4,8 +4,8 @@ import pytest from _pytest.fixtures import FixtureRequest +from sqlalchemy import Dialect from sqlalchemy.dialects import mysql, postgresql -from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, Column, @@ -28,13 +28,13 @@ @pytest.fixture def generator( - request: FixtureRequest, metadata: MetaData, engine: Engine + request: FixtureRequest, metadata: MetaData, dialect: Dialect ) -> CodeGenerator: options = getattr(request, "param", []) - return TablesGenerator(metadata, engine, options) + return TablesGenerator(metadata, dialect, options) -@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["postgresql"], indirect=["dialect"]) def test_fancy_coltypes(generator: CodeGenerator) -> None: Table( "simple_items", @@ -92,7 +92,7 @@ def test_boolean_detection(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["postgresql"], indirect=["dialect"]) def test_arrays(generator: CodeGenerator) -> None: Table( "simple_items", @@ -118,7 +118,7 @@ def test_arrays(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["postgresql"], indirect=["dialect"]) def test_jsonb(generator: CodeGenerator) -> None: Table( "simple_items", @@ -143,7 +143,7 @@ def test_jsonb(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["postgresql"], indirect=["dialect"]) def test_jsonb_default(generator: CodeGenerator) -> None: Table("simple_items", generator.metadata, Column("jsonb", postgresql.JSONB)) @@ -188,7 +188,7 @@ def test_enum_detection(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["postgresql"], indirect=["dialect"]) def test_column_adaptation(generator: CodeGenerator) -> None: Table( "simple_items", @@ -214,7 +214,7 @@ def test_column_adaptation(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["mysql"], indirect=["dialect"]) def test_mysql_column_types(generator: CodeGenerator) -> None: Table( "simple_items", @@ -558,7 +558,7 @@ def test_pk_default(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["mysql"], indirect=["dialect"]) def test_mysql_timestamp(generator: CodeGenerator) -> None: Table( "simple", @@ -584,7 +584,7 @@ def test_mysql_timestamp(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["mysql"], indirect=["dialect"]) def test_mysql_integer_display_width(generator: CodeGenerator) -> None: Table( "simple_items", @@ -611,7 +611,7 @@ def test_mysql_integer_display_width(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["mysql"], indirect=["dialect"]) def test_mysql_tinytext(generator: CodeGenerator) -> None: Table( "simple_items", @@ -638,7 +638,7 @@ def test_mysql_tinytext(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["mysql"], indirect=["dialect"]) def test_mysql_mediumtext(generator: CodeGenerator) -> None: Table( "simple_items", @@ -665,7 +665,7 @@ def test_mysql_mediumtext(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["mysql"], indirect=["dialect"]) def test_mysql_longtext(generator: CodeGenerator) -> None: Table( "simple_items", @@ -877,7 +877,7 @@ def test_multiline_table_comment(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["postgresql"], indirect=["dialect"]) def test_postgresql_sequence_standard_name(generator: CodeGenerator) -> None: Table( "simple_items", @@ -906,7 +906,7 @@ def test_postgresql_sequence_standard_name(generator: CodeGenerator) -> None: ) -@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["postgresql"], indirect=["dialect"]) def test_postgresql_sequence_nonstandard_name(generator: CodeGenerator) -> None: Table( "simple_items", @@ -944,7 +944,7 @@ def test_postgresql_sequence_nonstandard_name(generator: CodeGenerator) -> None: pytest.param('"my.schema"', '"test_seq"'), ], ) -@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +@pytest.mark.parametrize("dialect", ["postgresql"], indirect=["dialect"]) def test_postgresql_sequence_with_schema( generator: CodeGenerator, schemaname: str, seqname: str ) -> None: