Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: support for pydantic models - first cut #322

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
dist
build
venv*
.venv
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"SQLAlchemy >= 2.0.23",
"inflect >= 4.0.0",
"importlib_metadata; python_version < '3.10'",
"pydantic >= 2.0",
]
dynamic = ["version"]

Expand All @@ -55,6 +56,7 @@ pgvector = ["pgvector >= 0.2.4"]
tables = "sqlacodegen.generators:TablesGenerator"
declarative = "sqlacodegen.generators:DeclarativeGenerator"
dataclasses = "sqlacodegen.generators:DataclassGenerator"
pydanticmodels = "sqlacodegen.generators:PydanticGenerator"
sqlmodels = "sqlacodegen.generators:SQLModelGenerator"

[project.scripts]
Expand All @@ -65,7 +67,7 @@ version_scheme = "post-release"
local_scheme = "dirty-tag"

[tool.ruff]
select = [
lint.select = [
"E", "F", "W", # default Flake8
"I", # isort
"ISC", # flake8-implicit-str-concat
Expand Down Expand Up @@ -97,6 +99,9 @@ skip_missing_interpreters = true
minversion = 4.0.0

[testenv]
extras = test
extras =
test
sqlmodel
pydantic
commands = python -m pytest {posargs}
"""
229 changes: 228 additions & 1 deletion src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from keyword import iskeyword
from pprint import pformat
from textwrap import indent
from typing import Any, ClassVar
from typing import Any, ClassVar, Optional

import inflect
import pydantic
import sqlalchemy
from sqlalchemy import (
ARRAY,
Expand Down Expand Up @@ -1301,6 +1302,232 @@ def render_join(terms: list[JoinType]) -> str:
)


class PydanticGenerator(DeclarativeGenerator):
    """Code generator that emits pydantic ``BaseModel`` classes.

    Each generated class carries ``model_config = ConfigDict(from_attributes=True)``
    followed by one annotated field per table column. Relationships are still
    computed while building the models, but only column attributes are
    rendered.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        # NOTE(review): the default is reproduced exactly as it appears in this
        # view (a single space); confirm it matches the default used by the
        # parent generator.
        indentation: str = " ",
        base_class_name: str = "BaseModel",
    ):
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )

    def generate_base(self) -> None:
        """Set up a base that only imports ``BaseModel`` and ``ConfigDict``."""
        self.base = Base(
            literal_imports=[
                LiteralImport("pydantic", "BaseModel"),
                LiteralImport("pydantic", "ConfigDict"),
            ],
            declarations=[],
            metadata_ref="",
        )

    def generate_models(self) -> list[Model]:
        """Build the models for every table in the metadata.

        According to the original author's note, this follows
        ``DeclarativeGenerator.generate_models`` except that every non-link
        table becomes a ``ModelClass`` whose columns are all plain
        ``ColumnAttribute`` entries.

        :return: the generated models, in the order the tables were visited
        """
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own set, don't
        # process them normally
        links: defaultdict[str, list[Model]] = defaultdict(list)
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all
            # columns are involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                tablename = fk_constraints[0].elements[0].column.table.name
                links[tablename].append(model)
                continue

            # Every remaining table becomes a class with one attribute per
            # column
            model = ModelClass(table)
            models_by_table_name[qualified_name] = model
            for column in table.c:
                model.columns.append(ColumnAttribute(model, column))

        # Add relationships
        for model in models_by_table_name.values():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model, models_by_table_name, links[model.table.name]
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        if "nojoined" not in self.options:
            for model in list(models_by_table_name.values()):
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                for constraint in model.table.foreign_key_constraints:
                    if set(get_column_names(constraint)) == pk_column_names:
                        target = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(target, ModelClass):
                            model.parent_class = target
                            target.children.append(model)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return list(models_by_table_name.values())

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect imports via the implementation that follows
        ``DeclarativeGenerator`` in the MRO, skipping the declarative one."""
        super(DeclarativeGenerator, self).collect_imports(models)

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect imports for the columns of ``model``.

        Constraints and indexes are not visited; ``render_class`` only emits
        column attributes.
        """
        for column in model.table.c:
            self.collect_imports_for_column(column)

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Register the import needed for the Python type of ``column``.

        Column types that do not implement ``python_type`` are skipped here;
        ``render_column_attribute`` annotates them as ``Any`` and registers
        that import itself.
        """
        # NOTE(review): render_column_attribute refers to non-builtin types as
        # ``module.name`` and registers a module import on its own, so the
        # literal import added here may be redundant - confirm before removing.
        try:
            python_type = column.type.python_type
        except NotImplementedError:
            return

        self.add_import(python_type)

    def add_import(self, obj: Any) -> None:
        """Register a literal import for ``obj`` (a type or an instance of one).

        Nothing is registered for builtins, for types living under
        ``sqlalchemy.dialects.`` or for types whose name is exported by the
        ``sqlalchemy`` package. Types whose name is exported by ``pydantic``
        are imported from ``pydantic``; anything else is imported from its own
        module.
        """
        # Don't store builtin imports
        if getattr(obj, "__module__", "builtins") == "builtins":
            return

        type_ = obj if isinstance(obj, type) else type(obj)
        pkgname: Optional[str]  # noqa: UP007

        if type_.__module__.startswith("sqlalchemy.dialects."):
            pkgname = None
        elif type_.__name__ in dir(sqlalchemy):
            pkgname = None
        elif type_.__name__ in dir(pydantic):
            pkgname = "pydantic"
        else:
            pkgname = type_.__module__

        if pkgname:
            self.add_literal_import(pkgname, type_.__name__)

    def render_class(self, model: ModelClass) -> str:
        """Render ``model`` as a pydantic class definition.

        The body consists of the ``model_config`` line followed by the column
        attributes. Relationship attributes are not rendered.
        """
        sections: list[str] = ["model_config = ConfigDict(from_attributes=True)"]

        rendered_column_attributes = [
            self.render_column_attribute(column_attr) for column_attr in model.columns
        ]
        if rendered_column_attributes:
            sections.append("\n".join(rendered_column_attributes))

        declaration = self.render_class_declaration(model)
        rendered_sections = "\n\n".join(
            indent(section, self.indentation) for section in sections
        )
        return f"{declaration}\n{rendered_sections}"

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render one column as an annotated pydantic field.

        Builtin Python types are rendered by name; a ``str`` column whose type
        has a non-``None`` ``length`` is rendered as a length-constrained
        ``Annotated`` string. Non-builtin types are rendered as
        ``module.name`` and their module is imported. Column types without a
        ``python_type`` fall back to ``Any``. Nullable columns are wrapped in
        ``Optional[...]`` and default to ``None``.
        """
        column = column_attr.column

        # Only the python_type lookup is expected to raise NotImplementedError
        try:
            python_type = column.type.python_type
        except NotImplementedError:
            self.add_literal_import("typing", "Any")
            column_python_type = "Any"
        else:
            python_type_name = python_type.__name__
            python_type_module = python_type.__module__
            if python_type_module != "builtins":
                column_python_type = f"{python_type_module}.{python_type_name}"
                self.add_module_import(python_type_module)
            else:
                # Not every type that maps to ``str`` is guaranteed to expose
                # a ``length`` attribute, so look it up defensively
                max_length = (
                    getattr(column.type, "length", None)
                    if python_type_name == "str"
                    else None
                )
                if max_length is not None:
                    column_python_type = self.render_column_type_str_length(
                        max_length
                    )
                else:
                    column_python_type = python_type_name

        if column.nullable:
            self.add_literal_import("typing", "Optional")
            return f"{column_attr.name}: Optional[{column_python_type}] = None"

        return f"{column_attr.name}: {column_python_type}"

    def render_column_type_str_length(self, length: int) -> str:
        """Return an ``Annotated`` string type limited to ``length`` characters,
        registering the imports the annotation needs."""
        self.add_literal_import("typing_extensions", "Annotated")
        self.add_literal_import("pydantic", "StringConstraints")

        return f"Annotated[str, StringConstraints(max_length={length})]"

    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """Delegate to the parent implementation unchanged."""
        return super().render_column(column, show_name, is_table)


class DataclassGenerator(DeclarativeGenerator):
def __init__(
self,
Expand Down
28 changes: 28 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,34 @@ class Foo(SQLModel, table=True):
)


def test_cli_pydanticmodels(db_path: Path, tmp_path: Path) -> None:
    """Run the CLI with the ``pydanticmodels`` generator and compare the file
    it writes against the expected module text."""
    outfile = tmp_path / "outfile"
    command = [
        "sqlacodegen",
        f"sqlite:///{db_path}",
        "--generator",
        "pydanticmodels",
        "--outfile",
        str(outfile),
    ]
    subprocess.run(command, check=True)

    # NOTE(review): the whitespace inside this literal is reproduced exactly as
    # it appears in this view; confirm indentation and blank lines against the
    # generator's real output.
    expected = """\
from pydantic import BaseModel, ConfigDict

class Foo(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: int
name: str
"""
    assert outfile.read_text() == expected


def test_main() -> None:
expected_version = version("sqlacodegen")
completed = subprocess.run(
Expand Down
45 changes: 45 additions & 0 deletions tests/test_generator_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy.dialects import mysql
from sqlalchemy.engine import Engine
from sqlalchemy.schema import Column, MetaData, Table

from sqlacodegen.generators import CodeGenerator, PydanticGenerator

from .conftest import validate_code


@pytest.fixture
def generator(
    request: FixtureRequest, metadata: MetaData, engine: Engine
) -> CodeGenerator:
    """Provide a ``PydanticGenerator`` bound to the test metadata and engine.

    Generator options are taken from ``request.param`` when the fixture is
    parametrized; otherwise an empty list is used.
    """
    return PydanticGenerator(metadata, engine, getattr(request, "param", []))


@pytest.mark.parametrize("engine", ["mysql"], indirect=["engine"])
def test_mysql_column_types(generator: CodeGenerator) -> None:
    """Check the fields generated for nullable MySQL INTEGER, VARCHAR(255) and
    TEXT columns."""
    # Constructing the Table registers it on the generator's metadata
    Table(
        "simple_items",
        generator.metadata,
        Column("id", mysql.INTEGER),
        Column("name", mysql.VARCHAR(255)),
        Column("text", mysql.TEXT),
    )

    # NOTE(review): the whitespace inside this literal is reproduced exactly as
    # it appears in this view; confirm indentation and blank lines against the
    # generator's real output.
    expected_code = """\
from typing import Optional

from pydantic import BaseModel, ConfigDict, StringConstraints
from typing_extensions import Annotated

class SimpleItems(BaseModel):
model_config = ConfigDict(from_attributes=True)

id: Optional[int] = None
name: Optional[Annotated[str, StringConstraints(max_length=255)]] = None
text: Optional[str] = None
"""
    generated_code = generator.generate()
    validate_code(generated_code, expected_code)
Loading