From 9777ce33f2785013c014cabaf5572eed1105c592 Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Wed, 15 Jan 2025 17:23:39 -0500 Subject: [PATCH] Add type argument to SQLAlchemySchema and SQLAlchemyAutoSchema --- src/marshmallow_sqlalchemy/schema.py | 8 +++++--- tests/conftest.py | 21 ++++++++++++++++++--- tests/test_sqlalchemy_schema.py | 18 +++++++++--------- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/marshmallow_sqlalchemy/schema.py b/src/marshmallow_sqlalchemy/schema.py index 243a4d1..244155f 100644 --- a/src/marshmallow_sqlalchemy/schema.py +++ b/src/marshmallow_sqlalchemy/schema.py @@ -9,7 +9,7 @@ from .convert import ModelConverter from .exceptions import IncorrectSchemaTypeError -from .load_instance_mixin import LoadInstanceMixin +from .load_instance_mixin import LoadInstanceMixin, _ModelType # This isn't really a field; it's a placeholder for the metaclass. @@ -202,7 +202,7 @@ def get_declared_sqla_fields(cls, base_fields, converter, opts, dict_cls): class SQLAlchemySchema( - LoadInstanceMixin.Schema, Schema, metaclass=SQLAlchemySchemaMeta + LoadInstanceMixin.Schema[_ModelType], Schema, metaclass=SQLAlchemySchemaMeta ): """Schema for a SQLAlchemy model or table. Use together with `auto_field` to generate fields from columns. @@ -226,7 +226,9 @@ class Meta: OPTIONS_CLASS = SQLAlchemySchemaOpts -class SQLAlchemyAutoSchema(SQLAlchemySchema, metaclass=SQLAlchemyAutoSchemaMeta): +class SQLAlchemyAutoSchema( + SQLAlchemySchema[_ModelType], metaclass=SQLAlchemyAutoSchemaMeta +): """Schema that automatically generates fields from the columns of a SQLAlchemy model or table. diff --git a/tests/conftest.py b/tests/conftest.py index aecf685..4a64975 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,15 @@ from __future__ import annotations import datetime as dt +from dataclasses import dataclass from enum import Enum -from types import SimpleNamespace from typing import Any import pytest import sqlalchemy as sa from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import ( + DeclarativeMeta, Mapped, backref, column_property, @@ -60,8 +61,22 @@ def session(Base, models, engine): CourseLevel = Enum("CourseLevel", "PRIMARY SECONDARY") +@dataclass +class Models: + Course: type[DeclarativeMeta] + School: type[DeclarativeMeta] + Student: type[DeclarativeMeta] + Teacher: type[DeclarativeMeta] + SubstituteTeacher: type[DeclarativeMeta] + Paper: type[DeclarativeMeta] + GradedPaper: type[DeclarativeMeta] + Seminar: type[DeclarativeMeta] + Lecture: type[DeclarativeMeta] + Keyword: type[DeclarativeMeta] + + @pytest.fixture() -def models(Base: type): +def models(Base: type) -> Models: # models adapted from https://github.com/wtforms/wtforms-sqlalchemy/blob/master/tests/tests.py student_course = sa.Table( "student_course", @@ -221,7 +236,7 @@ class Lecture(Base): "kw", "keyword", creator=lambda kw: Keyword(keyword=kw) ) - return SimpleNamespace( + return Models( Course=Course, School=School, Student=Student, diff --git a/tests/test_sqlalchemy_schema.py b/tests/test_sqlalchemy_schema.py index 9688183..b5a1623 100644 --- a/tests/test_sqlalchemy_schema.py +++ b/tests/test_sqlalchemy_schema.py @@ -48,7 +48,7 @@ class EntityMixin: @pytest.fixture def sqla_auto_model_schema(models, request) -> SQLAlchemyAutoSchema: - class TeacherSchema(SQLAlchemyAutoSchema): + class TeacherSchema(SQLAlchemyAutoSchema[models.Teacher]): class Meta: model = models.Teacher @@ -73,7 +73,7 @@ class Meta: @pytest.fixture def sqla_schema_with_relationships(models, request) -> SQLAlchemySchema: - class TeacherSchema(EntityMixin, SQLAlchemySchema): + class TeacherSchema(EntityMixin, SQLAlchemySchema[models.Teacher]): class Meta: model = models.Teacher @@ -87,7 +87,7 @@ class Meta: @pytest.fixture def sqla_auto_model_schema_with_relationships(models, request) -> SQLAlchemyAutoSchema: - class TeacherSchema(SQLAlchemyAutoSchema): + class TeacherSchema(SQLAlchemyAutoSchema[models.Teacher]): class Meta: model = models.Teacher include_relationships = True @@ -102,7 +102,7 @@ class Meta: @pytest.fixture def sqla_schema_with_fks(models, request) -> SQLAlchemySchema: - class TeacherSchema(EntityMixin, SQLAlchemySchema): + class TeacherSchema(EntityMixin, SQLAlchemySchema[models.Teacher]): class Meta: model = models.Teacher @@ -115,7 +115,7 @@ class Meta: @pytest.fixture def sqla_auto_model_schema_with_fks(models, request) -> SQLAlchemyAutoSchema: - class TeacherSchema(SQLAlchemyAutoSchema): + class TeacherSchema(SQLAlchemyAutoSchema[models.Teacher]): class Meta: model = models.Teacher include_fk = True @@ -186,7 +186,7 @@ def test_load(schema): class TestLoadInstancePerSchemaInstance: @pytest.fixture def schema_no_load_instance(self, models, session): - class TeacherSchema(SQLAlchemySchema): + class TeacherSchema(SQLAlchemySchema[models.Teacher]): # type: ignore[name-defined] class Meta: model = models.Teacher sqla_session = session @@ -208,7 +208,7 @@ class Meta(schema_no_load_instance.Meta): # type: ignore[name-defined] @pytest.fixture def auto_schema_no_load_instance(self, models, session): - class TeacherSchema(SQLAlchemyAutoSchema): + class TeacherSchema(SQLAlchemyAutoSchema[models.Teacher]): # type: ignore[name-defined] class Meta: model = models.Teacher sqla_session = session @@ -302,7 +302,7 @@ class TeacherSchema(SQLAlchemySchema): # https://github.com/marshmallow-code/marshmallow-sqlalchemy/issues/190 def test_auto_schema_skips_synonyms(models): - class TeacherSchema(SQLAlchemyAutoSchema): + class TeacherSchema(SQLAlchemyAutoSchema[models.Teacher]): # type: ignore[name-defined] class Meta: model = models.Teacher include_fk = True @@ -327,7 +327,7 @@ class Meta: # Regresion test https://github.com/marshmallow-code/marshmallow-sqlalchemy/issues/306 def test_auto_field_works_with_ordered_flag(models): - class StudentSchema(SQLAlchemyAutoSchema): + class StudentSchema(SQLAlchemyAutoSchema[models.Student]): # type: ignore[name-defined] class Meta: model = models.Student ordered = True