Skip to content

Commit

Permalink
refactor(generator): Enable use of generators without engine
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlek committed Apr 23, 2024
1 parent 2a60532 commit 164bbcd
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 56 deletions.
2 changes: 1 addition & 1 deletion src/sqlacodegen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main() -> None:

# Instantiate the generator
generator_class = generators[args.generator].load()
generator = generator_class(metadata, engine, options)
generator = generator_class(metadata, engine.dialect, options)

# Open the target file (if given)
with ExitStack() as stack:
Expand Down
32 changes: 16 additions & 16 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Computed,
Constraint,
DefaultClause,
Dialect,
Enum,
Float,
ForeignKey,
Expand All @@ -39,7 +40,6 @@
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import CompileError
from sqlalchemy.sql.elements import TextClause

Expand Down Expand Up @@ -95,10 +95,10 @@ class CodeGenerator(metaclass=ABCMeta):
valid_options: ClassVar[set[str]] = set()

def __init__(
self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
self, metadata: MetaData, dialect: Dialect, options: Sequence[str]
):
self.metadata: MetaData = metadata
self.bind: Connection | Engine = bind
self.dialect: Dialect = dialect
self.options: set[str] = set(options)

# Validate options
Expand All @@ -124,12 +124,12 @@ class TablesGenerator(CodeGenerator):
def __init__(
self,
metadata: MetaData,
bind: Connection | Engine,
dialect: Dialect,
options: Sequence[str],
*,
indentation: str = " ",
):
super().__init__(metadata, bind, options)
super().__init__(metadata, dialect, options)
self.indentation: str = indentation
self.imports: dict[str, set[str]] = defaultdict(set)
self.module_imports: set[str] = set()
Expand Down Expand Up @@ -562,7 +562,7 @@ def add_fk_options(*opts: Any) -> None:
]
add_fk_options(local_columns, remote_columns)
elif isinstance(constraint, CheckConstraint):
args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
args.append(repr(get_compiled_expression(constraint.sqltext, self.dialect)))
elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
args.extend(repr(col.name) for col in constraint.columns)
else:
Expand Down Expand Up @@ -608,7 +608,7 @@ def fix_column_types(self, table: Table) -> None:
# Detect check constraints for boolean and enum columns
for constraint in table.constraints.copy():
if isinstance(constraint, CheckConstraint):
sqltext = get_compiled_expression(constraint.sqltext, self.bind)
sqltext = get_compiled_expression(constraint.sqltext, self.dialect)

# Turn any integer-like column with a CheckConstraint like
# "column IN (0, 1)" into a Boolean
Expand Down Expand Up @@ -646,7 +646,7 @@ def fix_column_types(self, table: Table) -> None:
pass

# PostgreSQL specific fix: detect sequences from server_default
if column.server_default and self.bind.dialect.name == "postgresql":
if column.server_default and self.dialect.name == "postgresql":
if isinstance(column.server_default, DefaultClause) and isinstance(
column.server_default.arg, TextClause
):
Expand All @@ -661,7 +661,7 @@ def fix_column_types(self, table: Table) -> None:
column.server_default = None

def get_adapted_type(self, coltype: Any) -> Any:
compiled_type = coltype.compile(self.bind.engine.dialect)
compiled_type = coltype.compile(self.dialect)
for supercls in coltype.__class__.__mro__:
if not supercls.__name__.startswith("_") and hasattr(
supercls, "__visit_name__"
Expand All @@ -687,7 +687,7 @@ def get_adapted_type(self, coltype: Any) -> Any:
try:
# If the adapted column type does not render the same as the
# original, don't substitute it
if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
if new_coltype.compile(self.dialect) != compiled_type:
# Make an exception to the rule for Float and arrays of Float,
# since at least on PostgreSQL, Float can accurately represent
# both REAL and DOUBLE_PRECISION
Expand Down Expand Up @@ -718,13 +718,13 @@ class DeclarativeGenerator(TablesGenerator):
def __init__(
self,
metadata: MetaData,
bind: Connection | Engine,
dialect: Dialect,
options: Sequence[str],
*,
indentation: str = " ",
base_class_name: str = "Base",
):
super().__init__(metadata, bind, options, indentation=indentation)
super().__init__(metadata, dialect, options, indentation=indentation)
self.base_class_name: str = base_class_name
self.inflect_engine = inflect.engine()

Expand Down Expand Up @@ -1305,7 +1305,7 @@ class DataclassGenerator(DeclarativeGenerator):
def __init__(
self,
metadata: MetaData,
bind: Connection | Engine,
dialect: Dialect,
options: Sequence[str],
*,
indentation: str = " ",
Expand All @@ -1315,7 +1315,7 @@ def __init__(
):
super().__init__(
metadata,
bind,
dialect,
options,
indentation=indentation,
base_class_name=base_class_name,
Expand Down Expand Up @@ -1344,15 +1344,15 @@ class SQLModelGenerator(DeclarativeGenerator):
def __init__(
self,
metadata: MetaData,
bind: Connection | Engine,
dialect: Dialect,
options: Sequence[str],
*,
indentation: str = " ",
base_class_name: str = "SQLModel",
):
super().__init__(
metadata,
bind,
dialect,
options,
indentation=indentation,
base_class_name=base_class_name,
Expand Down
7 changes: 3 additions & 4 deletions src/sqlacodegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from collections.abc import Mapping
from typing import Any

from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
from sqlalchemy.engine import Connection, Engine
from sqlalchemy import Dialect, PrimaryKeyConstraint, UniqueConstraint
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import (
Expand Down Expand Up @@ -34,9 +33,9 @@ def get_constraint_sort_key(constraint: Constraint) -> str:
return str(constraint)


def get_compiled_expression(statement: ClauseElement, bind: Engine | Connection) -> str:
def get_compiled_expression(statement: ClauseElement, dialect: Dialect) -> str:
"""Return the statement in a form where any placeholders have been filled in."""
return str(statement.compile(bind, compile_kwargs={"literal_binds": True}))
return str(statement.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))


def get_common_fk_constraints(
Expand Down
17 changes: 9 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@

import pytest
from pytest import FixtureRequest
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy import Dialect
from sqlalchemy.dialects import mysql, postgresql, sqlite
from sqlalchemy.orm import clear_mappers, configure_mappers
from sqlalchemy.schema import MetaData


@pytest.fixture
def engine(request: FixtureRequest) -> Engine:
dialect = getattr(request, "param", None)
if dialect == "postgresql":
return create_engine("postgresql:///testdb")
elif dialect == "mysql":
return create_engine("mysql+mysqlconnector://testdb")
def dialect(request: FixtureRequest) -> Dialect:
dialect_name = getattr(request, "param", None)
if dialect_name == "postgresql":
return postgresql.dialect()
elif dialect_name == "mysql":
return mysql.mysqlconnector.dialect()
else:
return create_engine("sqlite:///:memory:")
return sqlite.dialect()


@pytest.fixture
Expand Down
6 changes: 3 additions & 3 deletions tests/test_generator_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import Dialect
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine import Engine
from sqlalchemy.schema import Column, ForeignKeyConstraint, MetaData, Table
from sqlalchemy.sql.expression import text
from sqlalchemy.types import INTEGER, VARCHAR
Expand All @@ -15,10 +15,10 @@

@pytest.fixture
def generator(
request: FixtureRequest, metadata: MetaData, engine: Engine
request: FixtureRequest, metadata: MetaData, dialect: Dialect
) -> CodeGenerator:
options = getattr(request, "param", [])
return DataclassGenerator(metadata, engine, options)
return DataclassGenerator(metadata, dialect, options)


def test_basic_class(generator: CodeGenerator) -> None:
Expand Down
7 changes: 3 additions & 4 deletions tests/test_generator_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy.engine import Engine
from sqlalchemy import Dialect, PrimaryKeyConstraint
from sqlalchemy.schema import (
CheckConstraint,
Column,
Expand All @@ -24,10 +23,10 @@

@pytest.fixture
def generator(
request: FixtureRequest, metadata: MetaData, engine: Engine
request: FixtureRequest, metadata: MetaData, dialect: Dialect
) -> CodeGenerator:
options = getattr(request, "param", [])
return DeclarativeGenerator(metadata, engine, options)
return DeclarativeGenerator(metadata, dialect, options)


def test_indexes(generator: CodeGenerator) -> None:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_generator_sqlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy.engine import Engine
from sqlalchemy import Dialect
from sqlalchemy.schema import (
CheckConstraint,
Column,
Expand All @@ -21,10 +21,10 @@

@pytest.fixture
def generator(
request: FixtureRequest, metadata: MetaData, engine: Engine
request: FixtureRequest, metadata: MetaData, dialect: Dialect
) -> CodeGenerator:
options = getattr(request, "param", [])
return SQLModelGenerator(metadata, engine, options)
return SQLModelGenerator(metadata, dialect, options)


def test_indexes(generator: CodeGenerator) -> None:
Expand Down
Loading

0 comments on commit 164bbcd

Please sign in to comment.