Skip to content

Commit

Permalink
Fix metadata usage for generating SQLAlchemy Table in SQLModel genera…
Browse files Browse the repository at this point in the history
…tor, fix agronholm#302
  • Loading branch information
THUzxj committed Dec 1, 2023
1 parent 7a77b21 commit d1fec15
Showing 1 changed file with 79 additions and 2 deletions.
81 changes: 79 additions & 2 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,84 @@ def __init__(
base_class_name=base_class_name,
)

def generate_models(self) -> list[Model]:
models_by_table_name: dict[str, Model] = {}

# Pick association tables from the metadata into their own set, don't process
# them normally
links: defaultdict[str, list[Model]] = defaultdict(lambda: [])
for table in self.metadata.sorted_tables:
qualified_name = qualified_table_name(table)

# Link tables have exactly two foreign key constraints and all columns are
# involved in them
fk_constraints = sorted(
table.foreign_key_constraints, key=get_constraint_sort_key
)
if len(fk_constraints) == 2 and all(
col.foreign_keys for col in table.columns
):
model = models_by_table_name[qualified_name] = Model(table)
tablename = fk_constraints[0].elements[0].column.table.name
links[tablename].append(model)
continue

# Only form model classes for tables that have a primary key and are not
# association tables
if not table.primary_key:
models_by_table_name[qualified_name] = Model(table)
else:
model = ModelClass(table)
models_by_table_name[qualified_name] = model

# Fill in the columns
for column in table.c:
column_attr = ColumnAttribute(model, column)
model.columns.append(column_attr)

# Add relationships
for model in models_by_table_name.values():
if isinstance(model, ModelClass):
self.generate_relationships(
model, models_by_table_name, links[model.table.name]
)

# Nest inherited classes in their superclasses to ensure proper ordering
if "nojoined" not in self.options:
for model in list(models_by_table_name.values()):
if not isinstance(model, ModelClass):
continue

pk_column_names = {col.name for col in model.table.primary_key.columns}
for constraint in model.table.foreign_key_constraints:
if set(get_column_names(constraint)) == pk_column_names:
target = models_by_table_name[
qualified_table_name(constraint.elements[0].column.table)
]
if isinstance(target, ModelClass):
model.parent_class = target
target.children.append(model)

# Change base if we have both tables and model classes
if any(
not isinstance(model, ModelClass) for model in models_by_table_name.values()
):
TablesGenerator.generate_base(self)

# Collect the imports
self.collect_imports(models_by_table_name.values())

# Rename models and their attributes that conflict with imports or other
# attributes
global_names = {
name for namespace in self.imports.values() for name in namespace
}
for model in models_by_table_name.values():
self.generate_model_name(model, global_names)
global_names.add(model.name)

return list(models_by_table_name.values())

def generate_base(self) -> None:
self.base = Base(
literal_imports=[],
Expand All @@ -1538,7 +1616,6 @@ def generate_base(self) -> None:
def collect_imports(self, models: Iterable[Model]) -> None:
super(DeclarativeGenerator, self).collect_imports(models)
if any(isinstance(model, ModelClass) for model in models):
self.remove_literal_import("sqlalchemy", "MetaData")
self.add_literal_import("sqlmodel", "SQLModel")
self.add_literal_import("sqlmodel", "Field")

Expand Down Expand Up @@ -1570,7 +1647,7 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
self.add_import(python_type)

def render_module_variables(self, models: list[Model]) -> str:
declarations: list[str] = []
declarations: list[str] = self.base.declarations
if any(not isinstance(model, ModelClass) for model in models):
if self.base.table_metadata_declaration is not None:
declarations.append(self.base.table_metadata_declaration)
Expand Down

0 comments on commit d1fec15

Please sign in to comment.