diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4745e50..7476a26 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: - id: mypy additional_dependencies: - pytest - - "sqlalchemy[mypy] < 2.0" + - "SQLAlchemy >= 2.0.29" - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 diff --git a/pyproject.toml b/pyproject.toml index c643426..35579d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ ] requires-python = ">=3.8" dependencies = [ - "SQLAlchemy >= 2.0.23", + "SQLAlchemy >= 2.0.29", "inflect >= 4.0.0", "importlib_metadata; python_version < '3.10'", ] diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 21eadb6..f155908 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -38,7 +38,7 @@ Text, UniqueConstraint, ) -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import DOMAIN, JSONB from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import CompileError from sqlalchemy.sql.elements import TextClause @@ -222,6 +222,12 @@ def collect_imports_for_column(self, column: Column[Any]) -> None: or column.type.astext_type.length is not None ): self.add_import(column.type.astext_type) + elif isinstance(column.type, DOMAIN): + self.add_import(column.type.data_type.__class__) + if isinstance(column.type.default, TextClause) or isinstance( + column.type.check, TextClause + ): + self.add_literal_import("sqlalchemy", "text") if column.default: self.add_import(column.default) @@ -534,6 +540,12 @@ def render_column_type(self, coltype: object) -> str: ): del kwargs["astext_type"] + if isinstance(coltype, DOMAIN): + if isinstance(coltype.default, TextClause): + kwargs["default"] = render_callable("text", repr(coltype.default.text)) + if isinstance(coltype.check, TextClause): + kwargs["check"] = render_callable("text", repr(coltype.check.text)) + if args or kwargs: return render_callable(coltype.__class__.__name__, *args, kwargs=kwargs) else: diff --git a/tests/test_generator_tables.py b/tests/test_generator_tables.py index bf6ff4e..619a1f9 100644 --- a/tests/test_generator_tables.py +++ b/tests/test_generator_tables.py @@ -188,6 +188,43 @@ def test_enum_detection(generator: CodeGenerator) -> None: ) +@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +def test_domain(generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column( + "postal_code", + postgresql.DOMAIN( + "us_postal_code", + Text, + constraint_name="valid_us_postal_code", + not_null=False, + check=text("VALUE ~ '^\\d{5}$' OR VALUE ~ '^\\d{5}-\\d{4}$'"), + ), + nullable=False, + ), + ) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, MetaData, Table, Text, text + from sqlalchemy.dialects.postgresql import DOMAIN + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('postal_code', DOMAIN('us_postal_code', Text(), \ +constraint_name='valid_us_postal_code', not_null=False, \ +check=text("VALUE ~ '^\\\\d{5}$' OR VALUE ~ '^\\\\d{5}-\\\\d{4}$'")), nullable=False) + ) + """, + ) + + @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) def test_column_adaptation(generator: CodeGenerator) -> None: Table(