Skip to content

Handle TextClause objects in DOMAIN expressions #338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Jun 4, 2025
Merged
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ repos:
- id: mypy
additional_dependencies:
- pytest
- "sqlalchemy[mypy] < 2.0"
- "SQLAlchemy >= 2.0.29"

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ classifiers = [
]
requires-python = ">=3.9"
dependencies = [
"SQLAlchemy >= 2.0.23",
"SQLAlchemy >= 2.0.29",
"inflect >= 4.0.0",
"importlib_metadata; python_version < '3.10'",
]
@@ -80,6 +80,7 @@ extend-select = [

[tool.mypy]
strict = true
disable_error_code = "no-untyped-call"

[tool.pytest.ini_options]
addopts = "-rsfE --tb=short"
23 changes: 16 additions & 7 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from keyword import iskeyword
from pprint import pformat
from textwrap import indent
from typing import Any, ClassVar
from typing import Any, ClassVar, Literal, cast

import inflect
import sqlalchemy
@@ -38,7 +38,7 @@
TypeDecorator,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import DOMAIN, JSONB
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import CompileError
from sqlalchemy.sql.elements import TextClause
@@ -228,6 +228,8 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
or column.type.astext_type.length is not None
):
self.add_import(column.type.astext_type)
elif isinstance(column.type, DOMAIN):
self.add_import(column.type.data_type.__class__)

if column.default:
self.add_import(column.default)
@@ -375,7 +377,7 @@ def render_table(self, table: Table) -> str:

args.append(self.render_constraint(constraint))

for index in sorted(table.indexes, key=lambda i: i.name):
for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
# One-column indexes should be rendered as index=True on columns
if len(index.columns) > 1 or not uses_default_name(index):
args.append(self.render_index(index))
@@ -467,7 +469,7 @@ def render_column(

if isinstance(column.server_default, DefaultClause):
kwargs["server_default"] = render_callable(
"text", repr(column.server_default.arg.text)
"text", repr(cast(TextClause, column.server_default.arg).text)
)
elif isinstance(column.server_default, Computed):
expression = str(column.server_default.sqltext)
@@ -514,12 +516,18 @@ def render_column_type(self, coltype: object) -> str:

value = getattr(coltype, param.name, missing)
default = defaults.get(param.name, missing)
if isinstance(value, TextClause):
self.add_literal_import("sqlalchemy", "text")
rendered_value = render_callable("text", repr(value.text))
else:
rendered_value = repr(value)

if value is missing or value == default:
use_kwargs = True
elif use_kwargs:
kwargs[param.name] = repr(value)
kwargs[param.name] = rendered_value
else:
args.append(repr(value))
args.append(rendered_value)

vararg = next(
(
@@ -1072,6 +1080,7 @@ def generate_relationship_name(
preferred_name = column_names[0][:-3]

if "use_inflect" in self.options:
inflected_name: str | Literal[False]
if relationship.type in (
RelationshipType.ONE_TO_MANY,
RelationshipType.MANY_TO_MANY,
@@ -1166,7 +1175,7 @@ def render_table_args(self, table: Table) -> str:
args.append(self.render_constraint(constraint))

# Render indexes
for index in sorted(table.indexes, key=lambda i: i.name):
for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
if len(index.columns) > 1 or not uses_default_name(index):
args.append(self.render_index(index))

8 changes: 6 additions & 2 deletions src/sqlacodegen/utils.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

import re
from collections.abc import Mapping
from typing import Any
from typing import Any, Literal, cast

from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
from sqlalchemy.engine import Connection, Engine
@@ -97,6 +97,7 @@ def uses_default_name(constraint: Constraint | Index) -> bool:
}
)

key: Literal["fk", "pk", "ix", "ck", "uq"]
if isinstance(constraint, Index):
key = "ix"
elif isinstance(constraint, CheckConstraint):
@@ -139,7 +140,10 @@ def uses_default_name(constraint: Constraint | Index) -> bool:
raise TypeError(f"Unknown constraint type: {constraint.__class__.__qualname__}")

try:
convention: str = table.metadata.naming_convention[key]
convention = cast(
Mapping[str, str],
table.metadata.naming_convention,
)[key]
return constraint.name == (convention % values)
except KeyError:
return False
74 changes: 74 additions & 0 deletions tests/test_generator_tables.py
Original file line number Diff line number Diff line change
@@ -205,6 +205,80 @@ def test_enum_detection(generator: CodeGenerator) -> None:
)


@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_domain_text(generator: CodeGenerator) -> None:
Table(
"simple_items",
generator.metadata,
Column(
"postal_code",
postgresql.DOMAIN(
"us_postal_code",
Text,
constraint_name="valid_us_postal_code",
not_null=False,
check=text("VALUE ~ '^\\d{5}$' OR VALUE ~ '^\\d{5}-\\d{4}$'"),
),
nullable=False,
),
)

validate_code(
generator.generate(),
"""\
from sqlalchemy import Column, MetaData, Table, Text, text
from sqlalchemy.dialects.postgresql import DOMAIN

metadata = MetaData()


t_simple_items = Table(
'simple_items', metadata,
Column('postal_code', DOMAIN('us_postal_code', Text(), \
constraint_name='valid_us_postal_code', not_null=False, \
check=text("VALUE ~ '^\\\\d{5}$' OR VALUE ~ '^\\\\d{5}-\\\\d{4}$'")), nullable=False)
)
""",
)


@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_domain_int(generator: CodeGenerator) -> None:
Table(
"simple_items",
generator.metadata,
Column(
"n",
postgresql.DOMAIN(
"positive_int",
INTEGER,
constraint_name="positive",
not_null=False,
check=text("VALUE > 0"),
),
nullable=False,
),
)

validate_code(
generator.generate(),
"""\
from sqlalchemy import Column, INTEGER, MetaData, Table, text
from sqlalchemy.dialects.postgresql import DOMAIN

metadata = MetaData()


t_simple_items = Table(
'simple_items', metadata,
Column('n', DOMAIN('positive_int', INTEGER(), \
constraint_name='positive', not_null=False, \
check=text('VALUE > 0')), nullable=False)
)
""",
)


@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_column_adaptation(generator: CodeGenerator) -> None:
Table(