From 015015224926216ee2903ef897e6359898a1cefe Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 24 May 2023 13:25:36 -0700 Subject: [PATCH 01/27] Possible way to mixin with DeclarativeBase --- .pdm-python | 1 + src/flask_sqlalchemy/extension.py | 6 +- src/flask_sqlalchemy/model.py | 104 +++++++++++++++++++++++++++++- tests/test_model.py | 43 ++++++++++++ 4 files changed, 151 insertions(+), 3 deletions(-) create mode 100644 .pdm-python diff --git a/.pdm-python b/.pdm-python new file mode 100644 index 00000000..dd84e0de --- /dev/null +++ b/.pdm-python @@ -0,0 +1 @@ +/Users/pamelafox/flask-sqlalchemy/.venv/bin/python diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 27b0fd21..889d0379 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -460,7 +460,11 @@ def _make_declarative_base( .. versionchanged:: 2.3 ``model`` can be an already created declarative model class. """ - if not isinstance(model, sa_orm.DeclarativeMeta): + if ( + not isinstance(model, sa_orm.DeclarativeMeta) + and not issubclass(model, sa_orm.DeclarativeBase) + and not issubclass(model, sa_orm.DeclarativeBaseNoMeta) + ): metadata = self._make_metadata(None) model = sa_orm.declarative_base( metadata=metadata, cls=model, name="Model", metaclass=DefaultMeta diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 4f447e60..857e4b02 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -174,7 +174,8 @@ def should_set_tablename(cls: type) -> bool: joined-table inheritance. If no primary key is found, the name will be unset. """ if cls.__dict__.get("__abstract__", False) or not any( - isinstance(b, sa_orm.DeclarativeMeta) for b in cls.__mro__[1:] + (isinstance(b, sa_orm.DeclarativeMeta) or b is sa_orm.DeclarativeBase) + for b in cls.__mro__[1:] ): return False @@ -188,7 +189,10 @@ def should_set_tablename(cls: type) -> bool: return not ( base is cls or base.__dict__.get("__abstract__", False) - or not isinstance(base, sa_orm.DeclarativeMeta) + or not ( + isinstance(base, sa_orm.DeclarativeMeta) + or base is sa_orm.DeclarativeBase + ) ) return True @@ -204,3 +208,99 @@ class DefaultMeta(BindMetaMixin, NameMetaMixin, sa_orm.DeclarativeMeta): """SQLAlchemy declarative metaclass that provides ``__bind_key__`` and ``__tablename__`` support. """ + + +class DefaultMixin: + """A mixin that provides Flask-SQLAlchemy default functionality: + * sets a model's ``__tablename__`` by converting the + ``CamelCase`` class name to ``snake_case``. A name is set for non-abstract models + that do not otherwise define ``__tablename__``. If a model does not define a primary + key, it will not generate a name or ``__table__``, for single-table inheritance. + * sets a model's ``metadata`` based on its ``__bind_key__``. + If the model sets ``metadata`` or ``__table__`` directly, ``__bind_key__`` is + ignored. If the ``metadata`` is the same as the parent model, it will not be set + directly on the child model. + * Provides a default repr based on the model's primary key. + """ + + __fsa__: SQLAlchemy + metadata: sa.MetaData + __tablename__: str + __table__: sa.Table + + def __init_subclass__(cls, **kwargs): + if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__): + bind_key = getattr(cls, "__bind_key__", None) + parent_metadata = getattr(cls, "metadata", None) + metadata = cls.__fsa__._make_metadata(bind_key) + + if metadata is not parent_metadata: + cls.metadata = metadata + + if should_set_tablename(cls): + cls.__tablename__ = camel_to_snake_case(cls.__name__) + + super().__init_subclass__(**kwargs) + + # __table_cls__ has run. If no table was created, use the parent table. + if ( + "__tablename__" not in cls.__dict__ + and "__table__" in cls.__dict__ + and cls.__dict__["__table__"] is None + ): + del cls.__table__ + + @classmethod + def __table_cls__(cls, *args: t.Any, **kwargs: t.Any) -> sa.Table | None: + """This is called by SQLAlchemy during mapper setup. It determines the final + table object that the model will use. + + If no primary key is found, that indicates single-table inheritance, so no table + will be created and ``__tablename__`` will be unset. + """ + schema = kwargs.get("schema") + + if schema is None: + key = args[0] + else: + key = f"{schema}.{args[0]}" + + # Check if a table with this name already exists. Allows reflected tables to be + # applied to models by name. + if key in cls.metadata.tables: + return sa.Table(*args, **kwargs) + + # If a primary key is found, create a table for joined-table inheritance. + for arg in args: + if (isinstance(arg, sa.Column) and arg.primary_key) or isinstance( + arg, sa.PrimaryKeyConstraint + ): + return sa.Table(*args, **kwargs) + + # If no base classes define a table, return one that's missing a primary key + # so SQLAlchemy shows the correct error. + for base in cls.__mro__[1:-1]: + if "__table__" in base.__dict__: + break + else: + return sa.Table(*args, **kwargs) + + # Single-table inheritance, use the parent table name. __init__ will unset + # __table__ based on this. + if "__tablename__" in cls.__dict__: + del cls.__tablename__ + + return None + + def __repr__(self) -> str: + state = sa.inspect(self) + assert state is not None + + if state.transient: + pk = f"(transient {id(self)})" + elif state.pending: + pk = f"(pending {id(self)})" + else: + pk = ", ".join(map(str, state.identity)) + + return f"<{type(self).__name__} {pk}>" diff --git a/tests/test_model.py b/tests/test_model.py index 7d3aaa8a..ea42b500 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -6,9 +6,13 @@ import sqlalchemy as sa import sqlalchemy.orm as sa_orm from flask import Flask +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy.model import DefaultMeta +from flask_sqlalchemy.model import DefaultMixin from flask_sqlalchemy.model import Model @@ -28,6 +32,45 @@ class CustomModel(Model): assert isinstance(db.Model, DefaultMeta) +@pytest.mark.usefixtures("app_ctx") +def test_custom_model_sqlalchemy20_class(app: Flask) -> None: + from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept + + class Base(DeclarativeBase): + pass + + db = SQLAlchemy(app, model_class=Base) + + # Check the model class is instantiated with the correct metaclass + assert issubclass(db.Model, Base) + assert isinstance(db.Model, type) + assert isinstance(db.Model, DeclarativeAttributeIntercept) + # Check that additional attributes are added to the model class + assert db.Model.query_class is db.Query + + # Now create a model that inherits from that declarative base + class Quiz(DefaultMixin, db.Model): + id: Mapped[int] = mapped_column( + db.Integer, primary_key=True, autoincrement=True + ) + title: Mapped[str] = mapped_column(db.String(255), nullable=False) + + assert Quiz.__tablename__ == "quiz" + assert isinstance(Quiz, DeclarativeAttributeIntercept) + + db.create_all() + quiz = Quiz(title="Python trivia") + db.session.add(quiz) + db.session.commit() + + # Check column types are correct + quiz_id: int = quiz.id + quiz_title: str = quiz.title + assert quiz_id == 1 + assert quiz_title == "Python trivia" + assert repr(quiz) == f"" + + @pytest.mark.usefixtures("app_ctx") @pytest.mark.parametrize("base", [Model, object]) def test_custom_declarative_class(app: Flask, base: t.Any) -> None: From 7d0f5e58d604b92fb873870ba491028145450acd Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 24 May 2023 13:26:47 -0700 Subject: [PATCH 02/27] Remove pdm-python file --- .pdm-python | 1 - 1 file changed, 1 deletion(-) delete mode 100644 .pdm-python diff --git a/.pdm-python b/.pdm-python deleted file mode 100644 index dd84e0de..00000000 --- a/.pdm-python +++ /dev/null @@ -1 +0,0 @@ -/Users/pamelafox/flask-sqlalchemy/.venv/bin/python From d073fa673a5fdbc16b5c3bc4f8da5a190ced95ad Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 24 May 2023 17:10:18 -0700 Subject: [PATCH 03/27] Try different approach to mixins for 2.0 --- src/flask_sqlalchemy/extension.py | 16 ++- src/flask_sqlalchemy/model.py | 164 ++++++++++++++++-------------- tests/test_model.py | 13 ++- 3 files changed, 108 insertions(+), 85 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 889d0379..88e8d3ad 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import types import typing as t from weakref import WeakKeyDictionary @@ -14,8 +15,11 @@ from flask import has_app_context from .model import _QueryProperty +from .model import BindMixin from .model import DefaultMeta from .model import Model +from .model import NameMixin +from .model import ReprMixin from .pagination import Pagination from .pagination import SelectPagination from .query import Query @@ -439,7 +443,7 @@ def __new__( return Table def _make_declarative_base( - self, model: type[Model] | sa_orm.DeclarativeMeta + self, model: type[Model] | sa_orm.DeclarativeMeta | sa_orm.DeclarativeBase ) -> type[t.Any]: """Create a SQLAlchemy declarative model class. The result is available as :attr:`Model`. @@ -460,7 +464,15 @@ def _make_declarative_base( .. versionchanged:: 2.3 ``model`` can be an already created declarative model class. """ - if ( + if model is sa_orm.DeclarativeBase: + body = {"__fsa__": self} + model = types.new_class( + "Base", + (BindMixin, NameMixin, ReprMixin, sa_orm.DeclarativeBase), + {"metaclass": type(sa_orm.DeclarativeBase)}, + lambda ns: ns.update(body), + ) + elif ( not isinstance(model, sa_orm.DeclarativeMeta) and not issubclass(model, sa_orm.DeclarativeBase) and not issubclass(model, sa_orm.DeclarativeBaseNoMeta) diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 857e4b02..da60ff95 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -67,6 +67,21 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {pk}>" +class ReprMixin: + def __repr__(self) -> str: + state = sa.inspect(self) + assert state is not None + + if state.transient: + pk = f"(transient {id(self)})" + elif state.pending: + pk = f"(pending {id(self)})" + else: + pk = ", ".join(map(str, state.identity)) + + return f"<{type(self).__name__} {pk}>" + + class BindMetaMixin(type): """Metaclass mixin that sets a model's ``metadata`` based on its ``__bind_key__``. @@ -92,6 +107,29 @@ def __init__( super().__init__(name, bases, d, **kwargs) +class BindMixin: + """DeclarativeBase mixin to set a model's ``metadata`` based on ``__bind_key__``. + + If the model sets ``metadata`` or ``__table__`` directly, ``__bind_key__`` is + ignored. If the ``metadata`` is the same as the parent model, it will not be set + directly on the child model. + """ + + __fsa__: SQLAlchemy + metadata: sa.MetaData + + def __init_subclass__(cls, **kwargs): + if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__): + bind_key = getattr(cls, "__bind_key__", None) + parent_metadata = getattr(cls, "metadata", None) + metadata = cls.__fsa__._make_metadata(bind_key) + + if metadata is not parent_metadata: + cls.metadata = metadata + + super().__init_subclass__(**kwargs) + + class NameMetaMixin(type): """Metaclass mixin that sets a model's ``__tablename__`` by converting the ``CamelCase`` class name to ``snake_case``. A name is set for non-abstract models @@ -161,82 +199,19 @@ def __table_cls__(cls, *args: t.Any, **kwargs: t.Any) -> sa.Table | None: return None -def should_set_tablename(cls: type) -> bool: - """Determine whether ``__tablename__`` should be generated for a model. - - - If no class in the MRO sets a name, one should be generated. - - If a declared attr is found, it should be used instead. - - If a name is found, it should be used if the class is a mixin, otherwise one - should be generated. - - Abstract models should not have one generated. - - Later, ``__table_cls__`` will determine if the model looks like single or - joined-table inheritance. If no primary key is found, the name will be unset. - """ - if cls.__dict__.get("__abstract__", False) or not any( - (isinstance(b, sa_orm.DeclarativeMeta) or b is sa_orm.DeclarativeBase) - for b in cls.__mro__[1:] - ): - return False - - for base in cls.__mro__: - if "__tablename__" not in base.__dict__: - continue - - if isinstance(base.__dict__["__tablename__"], sa_orm.declared_attr): - return False - - return not ( - base is cls - or base.__dict__.get("__abstract__", False) - or not ( - isinstance(base, sa_orm.DeclarativeMeta) - or base is sa_orm.DeclarativeBase - ) - ) - - return True - - -def camel_to_snake_case(name: str) -> str: - """Convert a ``CamelCase`` name to ``snake_case``.""" - name = re.sub(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))", r"_\1", name) - return name.lower().lstrip("_") - - -class DefaultMeta(BindMetaMixin, NameMetaMixin, sa_orm.DeclarativeMeta): - """SQLAlchemy declarative metaclass that provides ``__bind_key__`` and - ``__tablename__`` support. - """ - - -class DefaultMixin: - """A mixin that provides Flask-SQLAlchemy default functionality: - * sets a model's ``__tablename__`` by converting the +class NameMixin: + """DeclarativeBase mixin that sets a model's ``__tablename__`` by converting the ``CamelCase`` class name to ``snake_case``. A name is set for non-abstract models that do not otherwise define ``__tablename__``. If a model does not define a primary key, it will not generate a name or ``__table__``, for single-table inheritance. - * sets a model's ``metadata`` based on its ``__bind_key__``. - If the model sets ``metadata`` or ``__table__`` directly, ``__bind_key__`` is - ignored. If the ``metadata`` is the same as the parent model, it will not be set - directly on the child model. - * Provides a default repr based on the model's primary key. """ - __fsa__: SQLAlchemy metadata: sa.MetaData __tablename__: str __table__: sa.Table + @classmethod def __init_subclass__(cls, **kwargs): - if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__): - bind_key = getattr(cls, "__bind_key__", None) - parent_metadata = getattr(cls, "metadata", None) - metadata = cls.__fsa__._make_metadata(bind_key) - - if metadata is not parent_metadata: - cls.metadata = metadata - if should_set_tablename(cls): cls.__tablename__ = camel_to_snake_case(cls.__name__) @@ -292,15 +267,52 @@ def __table_cls__(cls, *args: t.Any, **kwargs: t.Any) -> sa.Table | None: return None - def __repr__(self) -> str: - state = sa.inspect(self) - assert state is not None - if state.transient: - pk = f"(transient {id(self)})" - elif state.pending: - pk = f"(pending {id(self)})" - else: - pk = ", ".join(map(str, state.identity)) +def should_set_tablename(cls: type) -> bool: + """Determine whether ``__tablename__`` should be generated for a model. - return f"<{type(self).__name__} {pk}>" + - If no class in the MRO sets a name, one should be generated. + - If a declared attr is found, it should be used instead. + - If a name is found, it should be used if the class is a mixin, otherwise one + should be generated. + - Abstract models should not have one generated. + + Later, ``__table_cls__`` will determine if the model looks like single or + joined-table inheritance. If no primary key is found, the name will be unset. + """ + + # TODO: or not any(isinstance(b, sa.orm.DeclarativeMeta) for b in cls.__mro__[1:]) \ + if cls.__dict__.get("__abstract__", False) or any( + b is sa_orm.DeclarativeBase for b in cls.__bases__ + ): + return False + + for base in cls.__mro__: + if "__tablename__" not in base.__dict__: + continue + + if isinstance(base.__dict__["__tablename__"], sa_orm.declared_attr): + return False + + return not ( + base is cls + or base.__dict__.get("__abstract__", False) + or not ( + isinstance(base, sa_orm.DeclarativeMeta) + or isinstance(base, sa_orm.decl_api.DeclarativeAttributeIntercept) + ) + ) + + return True + + +def camel_to_snake_case(name: str) -> str: + """Convert a ``CamelCase`` name to ``snake_case``.""" + name = re.sub(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))", r"_\1", name) + return name.lower().lstrip("_") + + +class DefaultMeta(BindMetaMixin, NameMetaMixin, sa_orm.DeclarativeMeta): + """SQLAlchemy declarative metaclass that provides ``__bind_key__`` and + ``__tablename__`` support. + """ diff --git a/tests/test_model.py b/tests/test_model.py index ea42b500..bcc623dd 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -12,7 +12,6 @@ from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy.model import DefaultMeta -from flask_sqlalchemy.model import DefaultMixin from flask_sqlalchemy.model import Model @@ -36,26 +35,26 @@ class CustomModel(Model): def test_custom_model_sqlalchemy20_class(app: Flask) -> None: from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept - class Base(DeclarativeBase): - pass - - db = SQLAlchemy(app, model_class=Base) + app.config["SQLALCHEMY_BINDS"] = {"other": "sqlite://"} + db = SQLAlchemy(app, model_class=DeclarativeBase) # Check the model class is instantiated with the correct metaclass - assert issubclass(db.Model, Base) + assert issubclass(db.Model, DeclarativeBase) assert isinstance(db.Model, type) assert isinstance(db.Model, DeclarativeAttributeIntercept) # Check that additional attributes are added to the model class assert db.Model.query_class is db.Query # Now create a model that inherits from that declarative base - class Quiz(DefaultMixin, db.Model): + class Quiz(db.Model): + __bind_key__ = "other" id: Mapped[int] = mapped_column( db.Integer, primary_key=True, autoincrement=True ) title: Mapped[str] = mapped_column(db.String(255), nullable=False) assert Quiz.__tablename__ == "quiz" + assert Quiz.metadata is db.metadatas["other"] assert isinstance(Quiz, DeclarativeAttributeIntercept) db.create_all() From 5ff15bc239ea6b07e8ee32158319d3887b14149f Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Thu, 25 May 2023 11:41:31 -0700 Subject: [PATCH 04/27] Support MappedAsDataClass --- src/flask_sqlalchemy/extension.py | 23 +++++++----- src/flask_sqlalchemy/model.py | 3 +- tests/test_model.py | 60 +++++++++++++++++++++++++------ 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 88e8d3ad..34cfe4bb 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -464,19 +464,26 @@ def _make_declarative_base( .. versionchanged:: 2.3 ``model`` can be an already created declarative model class. """ - if model is sa_orm.DeclarativeBase: + declarative_bases = [ + b + for b in model.__bases__ + if issubclass(b, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)) + ] + if len(declarative_bases) > 1: + # raise error if more than one declarative base is found + raise ValueError( + "Only one declarative base can be passed to SQLAlchemy." + " Got: {}".format(model.__bases__) + ) + elif len(declarative_bases) == 1: body = {"__fsa__": self} model = types.new_class( "Base", - (BindMixin, NameMixin, ReprMixin, sa_orm.DeclarativeBase), - {"metaclass": type(sa_orm.DeclarativeBase)}, + (BindMixin, NameMixin, ReprMixin, *model.__bases__), + {"metaclass": type(declarative_bases[0])}, lambda ns: ns.update(body), ) - elif ( - not isinstance(model, sa_orm.DeclarativeMeta) - and not issubclass(model, sa_orm.DeclarativeBase) - and not issubclass(model, sa_orm.DeclarativeBaseNoMeta) - ): + elif not isinstance(model, sa_orm.DeclarativeMeta): metadata = self._make_metadata(None) model = sa_orm.declarative_base( metadata=metadata, cls=model, name="Model", metaclass=DefaultMeta diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index da60ff95..58d5af57 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -283,7 +283,8 @@ def should_set_tablename(cls: type) -> bool: # TODO: or not any(isinstance(b, sa.orm.DeclarativeMeta) for b in cls.__mro__[1:]) \ if cls.__dict__.get("__abstract__", False) or any( - b is sa_orm.DeclarativeBase for b in cls.__bases__ + (b is sa_orm.DeclarativeBase or b is sa_orm.DeclarativeBaseNoMeta) + for b in cls.__bases__ ): return False diff --git a/tests/test_model.py b/tests/test_model.py index bcc623dd..8506dcd6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,6 @@ from __future__ import annotations +import types import typing as t import pytest @@ -7,8 +8,10 @@ import sqlalchemy.orm as sa_orm from flask import Flask from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBaseNoMeta from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedAsDataclass from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy.model import DefaultMeta @@ -31,31 +34,63 @@ class CustomModel(Model): assert isinstance(db.Model, DefaultMeta) +test_classes = [ + types.new_class( + "BaseModel", (DeclarativeBase,), {"metaclass": type(sa.orm.DeclarativeBase)} + ), + types.new_class( + "BaseMappedModel", + (MappedAsDataclass, DeclarativeBase), + {"metaclass": type(sa.orm.DeclarativeBase)}, + ), + types.new_class( + "BaseModel", + (DeclarativeBaseNoMeta,), + {"metaclass": type(sa.orm.DeclarativeBaseNoMeta)}, + ), + types.new_class( + "BaseModel", + ( + MappedAsDataclass, + DeclarativeBaseNoMeta, + ), + {"metaclass": type(sa.orm.DeclarativeBaseNoMeta)}, + ), +] + + @pytest.mark.usefixtures("app_ctx") -def test_custom_model_sqlalchemy20_class(app: Flask) -> None: +@pytest.mark.parametrize("base", test_classes) +def test_sqlalchemy20(app: Flask, base: object) -> None: from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept app.config["SQLALCHEMY_BINDS"] = {"other": "sqlite://"} - db = SQLAlchemy(app, model_class=DeclarativeBase) + db = SQLAlchemy(app, model_class=base) # Check the model class is instantiated with the correct metaclass - assert issubclass(db.Model, DeclarativeBase) - assert isinstance(db.Model, type) - assert isinstance(db.Model, DeclarativeAttributeIntercept) + if issubclass(db.Model, DeclarativeBase) or issubclass(db.Model, MappedAsDataclass): + assert isinstance(db.Model, DeclarativeAttributeIntercept) + elif issubclass(db.Model, DeclarativeBaseNoMeta) and not issubclass( + db.Model, MappedAsDataclass + ): + assert not isinstance(db.Model, DeclarativeAttributeIntercept) # Check that additional attributes are added to the model class assert db.Model.query_class is db.Query + assert db.Model.__fsa__ is db + + if issubclass(base, MappedAsDataclass): + id_column = mapped_column(init=False, primary_key=True, autoincrement=True) + else: + id_column = mapped_column(primary_key=True, autoincrement=True) # Now create a model that inherits from that declarative base class Quiz(db.Model): __bind_key__ = "other" - id: Mapped[int] = mapped_column( - db.Integer, primary_key=True, autoincrement=True - ) + id: Mapped[int] = id_column title: Mapped[str] = mapped_column(db.String(255), nullable=False) assert Quiz.__tablename__ == "quiz" assert Quiz.metadata is db.metadatas["other"] - assert isinstance(Quiz, DeclarativeAttributeIntercept) db.create_all() quiz = Quiz(title="Python trivia") @@ -67,7 +102,12 @@ class Quiz(db.Model): quiz_title: str = quiz.title assert quiz_id == 1 assert quiz_title == "Python trivia" - assert repr(quiz) == f"" + if issubclass(base, MappedAsDataclass): + assert ( + repr(quiz) == "test_sqlalchemy20..Quiz(id=1, title='Python trivia')" + ) + else: + assert repr(quiz) == f"" @pytest.mark.usefixtures("app_ctx") From 8d78e37b636f2cf43067ca432c7991cc3acf257e Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Thu, 25 May 2023 11:56:37 -0700 Subject: [PATCH 05/27] Update should check --- src/flask_sqlalchemy/extension.py | 2 +- src/flask_sqlalchemy/model.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 34cfe4bb..ffc8f8b2 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -443,7 +443,7 @@ def __new__( return Table def _make_declarative_base( - self, model: type[Model] | sa_orm.DeclarativeMeta | sa_orm.DeclarativeBase + self, model: type[Model] | sa_orm.DeclarativeMeta ) -> type[t.Any]: """Create a SQLAlchemy declarative model class. The result is available as :attr:`Model`. diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 58d5af57..823abc17 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -280,11 +280,17 @@ def should_set_tablename(cls: type) -> bool: Later, ``__table_cls__`` will determine if the model looks like single or joined-table inheritance. If no primary key is found, the name will be unset. """ - - # TODO: or not any(isinstance(b, sa.orm.DeclarativeMeta) for b in cls.__mro__[1:]) \ - if cls.__dict__.get("__abstract__", False) or any( - (b is sa_orm.DeclarativeBase or b is sa_orm.DeclarativeBaseNoMeta) - for b in cls.__bases__ + uses_2pt0 = issubclass(cls, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)) + if ( + cls.__dict__.get("__abstract__", False) + or ( + not uses_2pt0 + and not any(isinstance(b, sa_orm.DeclarativeMeta) for b in cls.__mro__[1:]) + ) + or any( + (b is sa_orm.DeclarativeBase or b is sa_orm.DeclarativeBaseNoMeta) + for b in cls.__bases__ + ) ): return False From 41fdd5d80f547f2b63233bec04fe1fde37cfe97b Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 31 May 2023 12:40:07 -0700 Subject: [PATCH 06/27] Parameterizing all the tests --- src/flask_sqlalchemy/extension.py | 5 +-- src/flask_sqlalchemy/model.py | 17 +------- tests/conftest.py | 66 ++++++++++++++++++++++++++++--- tests/test_model.py | 18 ++++++--- 4 files changed, 76 insertions(+), 30 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index ffc8f8b2..80ef7674 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -19,7 +19,6 @@ from .model import DefaultMeta from .model import Model from .model import NameMixin -from .model import ReprMixin from .pagination import Pagination from .pagination import SelectPagination from .query import Query @@ -478,8 +477,8 @@ def _make_declarative_base( elif len(declarative_bases) == 1: body = {"__fsa__": self} model = types.new_class( - "Base", - (BindMixin, NameMixin, ReprMixin, *model.__bases__), + "FlaskSQLAlchemyBase", + (BindMixin, NameMixin, Model, *model.__bases__), {"metaclass": type(declarative_bases[0])}, lambda ns: ns.update(body), ) diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 823abc17..1ec98e1b 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -67,21 +67,6 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {pk}>" -class ReprMixin: - def __repr__(self) -> str: - state = sa.inspect(self) - assert state is not None - - if state.transient: - pk = f"(transient {id(self)})" - elif state.pending: - pk = f"(pending {id(self)})" - else: - pk = ", ".join(map(str, state.identity)) - - return f"<{type(self).__name__} {pk}>" - - class BindMetaMixin(type): """Metaclass mixin that sets a model's ``metadata`` based on its ``__bind_key__``. @@ -118,6 +103,7 @@ class BindMixin: __fsa__: SQLAlchemy metadata: sa.MetaData + @classmethod def __init_subclass__(cls, **kwargs): if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__): bind_key = getattr(cls, "__bind_key__", None) @@ -307,6 +293,7 @@ def should_set_tablename(cls: type) -> bool: or not ( isinstance(base, sa_orm.DeclarativeMeta) or isinstance(base, sa_orm.decl_api.DeclarativeAttributeIntercept) + or issubclass(base, sa_orm.DeclarativeBaseNoMeta) ) ) diff --git a/tests/conftest.py b/tests/conftest.py index 320225ca..b767ff99 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,15 @@ from __future__ import annotations +import types import typing as t from pathlib import Path import pytest import sqlalchemy as sa +import sqlalchemy.orm as sa_orm from flask import Flask from flask.ctx import AppContext +from sqlalchemy.orm import Mapped from flask_sqlalchemy import SQLAlchemy @@ -25,16 +28,67 @@ def app_ctx(app: Flask) -> t.Generator[AppContext, None, None]: yield ctx -@pytest.fixture -def db(app: Flask) -> SQLAlchemy: - return SQLAlchemy(app) +test_classes = [ + None, + types.new_class( + "BaseDeclarativeBase", + (sa_orm.DeclarativeBase,), + {"metaclass": type(sa_orm.DeclarativeBase)}, + ), + types.new_class( + "BaseDataclassDeclarativeBase", + (sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase), + {"metaclass": type(sa_orm.DeclarativeBase)}, + ), + types.new_class( + "BaseDeclarativeBaseNoMeta", + (sa_orm.DeclarativeBaseNoMeta,), + {"metaclass": type(sa_orm.DeclarativeBaseNoMeta)}, + ), + types.new_class( + "BaseDataclassDeclarativeBaseNoMeta", + ( + sa_orm.MappedAsDataclass, + sa_orm.DeclarativeBaseNoMeta, + ), + {"metaclass": type(sa_orm.DeclarativeBaseNoMeta)}, + ), +] + + +@pytest.fixture(params=test_classes) +def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy: + if request.param is not None: + return SQLAlchemy(app, model_class=request.param) + else: + return SQLAlchemy(app) @pytest.fixture def Todo(app: Flask, db: SQLAlchemy) -> t.Generator[t.Any, None, None]: - class Todo(db.Model): - id = sa.Column(sa.Integer, primary_key=True) - title = sa.Column(sa.String) + if issubclass(db.Model, (sa_orm.MappedAsDataclass)): + + class Todo(db.Model): + id: Mapped[int] = sa_orm.mapped_column( + sa.Integer, init=False, primary_key=True + ) + title: Mapped[str] = sa_orm.mapped_column( + sa.String, nullable=True, default=None + ) + + elif issubclass( + db.Model, (sa_orm.DeclarativeBaseNoMeta, sa_orm.DeclarativeBaseNoMeta) + ): + + class Todo(db.Model): + id: Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) + title: Mapped[str] = sa_orm.mapped_column(sa.String, nullable=True) + + else: + + class Todo(db.Model): + id: sa.Column = sa.Column(sa.Integer, primary_key=True) + title: sa.Column = sa.Column(sa.String) with app.app_context(): db.create_all() diff --git a/tests/test_model.py b/tests/test_model.py index 8506dcd6..75219136 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -18,7 +18,9 @@ from flask_sqlalchemy.model import Model -def test_default_model_class(db: SQLAlchemy) -> None: +def test_default_model_class(app: Flask) -> None: + db = SQLAlchemy(app) + assert db.Model.query_class is db.Query assert db.Model.metadata is db.metadata assert issubclass(db.Model, Model) @@ -130,8 +132,12 @@ class User(db.Model): db.create_all() user = User() - assert repr(user) == f"" - db.session.add(user) - assert repr(user) == f"" - db.session.flush() - assert repr(user) == f"" + + if issubclass(db.Model, MappedAsDataclass): + assert repr(user) == "test_model_repr..User()" + else: + assert repr(user) == f"" + db.session.add(user) + assert repr(user) == f"" + db.session.flush() + assert repr(user) == f"" From 31e8e2c4a0680f5f07316ecaeeb9dbbc9a9f8db9 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 31 May 2023 12:41:18 -0700 Subject: [PATCH 07/27] Remove redundant test --- tests/test_model.py | 84 +-------------------------------------------- 1 file changed, 1 insertion(+), 83 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 75219136..80ced173 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,17 +1,11 @@ from __future__ import annotations -import types import typing as t import pytest import sqlalchemy as sa import sqlalchemy.orm as sa_orm from flask import Flask -from sqlalchemy.orm import DeclarativeBase -from sqlalchemy.orm import DeclarativeBaseNoMeta -from sqlalchemy.orm import Mapped -from sqlalchemy.orm import mapped_column -from sqlalchemy.orm import MappedAsDataclass from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy.model import DefaultMeta @@ -36,82 +30,6 @@ class CustomModel(Model): assert isinstance(db.Model, DefaultMeta) -test_classes = [ - types.new_class( - "BaseModel", (DeclarativeBase,), {"metaclass": type(sa.orm.DeclarativeBase)} - ), - types.new_class( - "BaseMappedModel", - (MappedAsDataclass, DeclarativeBase), - {"metaclass": type(sa.orm.DeclarativeBase)}, - ), - types.new_class( - "BaseModel", - (DeclarativeBaseNoMeta,), - {"metaclass": type(sa.orm.DeclarativeBaseNoMeta)}, - ), - types.new_class( - "BaseModel", - ( - MappedAsDataclass, - DeclarativeBaseNoMeta, - ), - {"metaclass": type(sa.orm.DeclarativeBaseNoMeta)}, - ), -] - - -@pytest.mark.usefixtures("app_ctx") -@pytest.mark.parametrize("base", test_classes) -def test_sqlalchemy20(app: Flask, base: object) -> None: - from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept - - app.config["SQLALCHEMY_BINDS"] = {"other": "sqlite://"} - db = SQLAlchemy(app, model_class=base) - - # Check the model class is instantiated with the correct metaclass - if issubclass(db.Model, DeclarativeBase) or issubclass(db.Model, MappedAsDataclass): - assert isinstance(db.Model, DeclarativeAttributeIntercept) - elif issubclass(db.Model, DeclarativeBaseNoMeta) and not issubclass( - db.Model, MappedAsDataclass - ): - assert not isinstance(db.Model, DeclarativeAttributeIntercept) - # Check that additional attributes are added to the model class - assert db.Model.query_class is db.Query - assert db.Model.__fsa__ is db - - if issubclass(base, MappedAsDataclass): - id_column = mapped_column(init=False, primary_key=True, autoincrement=True) - else: - id_column = mapped_column(primary_key=True, autoincrement=True) - - # Now create a model that inherits from that declarative base - class Quiz(db.Model): - __bind_key__ = "other" - id: Mapped[int] = id_column - title: Mapped[str] = mapped_column(db.String(255), nullable=False) - - assert Quiz.__tablename__ == "quiz" - assert Quiz.metadata is db.metadatas["other"] - - db.create_all() - quiz = Quiz(title="Python trivia") - db.session.add(quiz) - db.session.commit() - - # Check column types are correct - quiz_id: int = quiz.id - quiz_title: str = quiz.title - assert quiz_id == 1 - assert quiz_title == "Python trivia" - if issubclass(base, MappedAsDataclass): - assert ( - repr(quiz) == "test_sqlalchemy20..Quiz(id=1, title='Python trivia')" - ) - else: - assert repr(quiz) == f"" - - @pytest.mark.usefixtures("app_ctx") @pytest.mark.parametrize("base", [Model, object]) def test_custom_declarative_class(app: Flask, base: t.Any) -> None: @@ -133,7 +51,7 @@ class User(db.Model): db.create_all() user = User() - if issubclass(db.Model, MappedAsDataclass): + if issubclass(db.Model, sa_orm.MappedAsDataclass): assert repr(user) == "test_model_repr..User()" else: assert repr(user) == f"" From b536d28185e47a87ef42fe8fb36f71f3cb01517f Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 31 May 2023 15:16:51 -0700 Subject: [PATCH 08/27] Updates to docs --- docs/api.rst | 19 +++++++++ docs/models.rst | 13 +++++++ docs/quickstart.rst | 64 ++++++++++++++++++++++++++----- src/flask_sqlalchemy/extension.py | 10 ++++- tests/conftest.py | 9 ++--- 5 files changed, 98 insertions(+), 17 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index da739922..6bcc5fb7 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -32,6 +32,13 @@ Model If the ``__table__`` or ``__tablename__`` is set explicitly, that will be used instead. +Metaclass mixins (SQLAlchemy 1.x) +--------------------------------- + +If your code uses the SQLAlchemy 1.x API (the default for code that doesn't specify a ``model_class``), +then these mixins are automatically applied to the ``Model`` class. They can also be used +directly to create custom metaclasses. See :doc:`customizing` for more information. + .. autoclass:: DefaultMeta .. autoclass:: BindMetaMixin @@ -39,6 +46,18 @@ Model .. autoclass:: NameMetaMixin +Base class mixins (SQLAlchemy 2.x) +---------------------------------- + +If your code uses the SQLAlchemy 2.x API by passing a subclass of ``DeclarativeBase`` +or ``DeclarativeBaseNoMeta`` as the ``model_class``, then the following classes +are automatically added as additional base classes. + +.. autoclass:: BindMixin + +.. autoclass:: NameMixin + + Session ------- diff --git a/docs/models.rst b/docs/models.rst index c434a4aa..56a7b16a 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -30,6 +30,19 @@ For convenience, the extension object provides access to names in the ``sqlalche ``sqlalchemy.orm`` modules. So you can use ``db.Column`` instead of importing and using ``sqlalchemy.Column``, although the two are equivalent. +It's also possible to use the SQLAlchemy 2.x style of defining models, +as long as you initialized the extension with an appropriate 2.x model base class +(as described in the quickstart). + +.. code-block:: python + + from sqlalchemy.orm import Mapped, mapped_column + + class User(db.Model): + id: Mapped[int] = mapped_column(db.Integer, primary_key=True) + username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False) + email: Mapped[str] = mapped_column(db.String) + Defining a model does not create it in the database. Use :meth:`~.SQLAlchemy.create_all` to create the models and tables after defining them. If you define models in submodules, you must import them so that SQLAlchemy knows about them before calling ``create_all``. diff --git a/docs/quickstart.rst b/docs/quickstart.rst index ae5e7176..6cee5db2 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -40,9 +40,46 @@ For example, to install or update the latest version using pip: .. _PyPI: https://pypi.org/project/Flask-SQLAlchemy/ +Initialize the Extension +------------------------ + +First create the ``db`` object using the ``SQLAlchemy`` constructor. + +.. code-block:: python + + from flask import Flask + from flask_sqlalchemy import SQLAlchemy + from sqlalchemy.orm import DeclarativeBase + + db = SQLAlchemy() + + +By default, this extension assumes that you are using the SQLAlchemy 1.x API for defining models. + +To use the new SQLAlchemy 2.x API, pass a subclass of either ``DeclarativeBase`` or ``DeclarativeBaseNoMeta`` +to the constructor. + +.. code-block:: python + + from flask import Flask + from flask_sqlalchemy import SQLAlchemy + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass + + db = SQLAlchemy(model_class=Base) + +Once constructed, the ``db`` object gives you access to the :attr:`db.Model <.SQLAlchemy.Model>` class to +define models, and the :attr:`db.session <.SQLAlchemy.session>` to execute queries. + +The :class:`SQLAlchemy` object also takes additional arguments to customize the +objects it manages. + Configure the Extension ----------------------- +The next step is to connect the extension to your Flask app. The only required Flask app config is the :data:`.SQLALCHEMY_DATABASE_URI` key. That is a connection string that tells SQLAlchemy what database to connect to. @@ -53,11 +90,6 @@ which is stored in the app's instance folder. .. code-block:: python - from flask import Flask - from flask_sqlalchemy import SQLAlchemy - - # create the extension - db = SQLAlchemy() # create the app app = Flask(__name__) # configure the SQLite database, relative to the app instance folder @@ -65,12 +97,8 @@ which is stored in the app's instance folder. # initialize the app with the extension db.init_app(app) -The ``db`` object gives you access to the :attr:`db.Model <.SQLAlchemy.Model>` class to -define models, and the :attr:`db.session <.SQLAlchemy.session>` to execute queries. - See :doc:`config` for an explanation of connections strings and what other configuration -keys are used. The :class:`SQLAlchemy` object also takes some arguments to customize the -objects it manages. +keys are used. Define Models @@ -81,6 +109,8 @@ Subclass ``db.Model`` to define a model class. The ``db`` object makes the names The model will generate a table name by converting the ``CamelCase`` class name to ``snake_case``. +This example uses the SQLAlchemy 1.x style of defining models: + .. code-block:: python class User(db.Model): @@ -90,6 +120,20 @@ The model will generate a table name by converting the ``CamelCase`` class name The table name ``"user"`` will automatically be assigned to the model's table. +It's also possible to use the SQLAlchemy 2.x style of defining models, +as long as you initialized the extension with an appropriate 2.x model base class +as described above. + +.. code-block:: python + + from sqlalchemy.orm import Mapped, mapped_column + + class User(db.Model): + id: Mapped[int] = mapped_column(db.Integer, primary_key=True) + username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False) + email: Mapped[str] = mapped_column(db.String) + + See :doc:`models` for more information about defining and creating models and tables. diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 80ef7674..8821b05f 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -202,9 +202,15 @@ def __init__( database engine. Otherwise, it will use the default :attr:`metadata` and :attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``. - Customize this by subclassing :class:`.Model` and passing the ``model_class`` - parameter to the extension. A fully created declarative model class can be + For code using the SQLAlchemy 1.x API, customize this model by subclassing + :class:`.Model` and passing the ``model_class`` parameter to the extension. + A fully created declarative model class can be passed as well, to use a custom metaclass. + + For code using the SQLAlchemy 2.x API, customize this model by subclassing + :class:`sqlalchemy.orm.DeclarativeBase` or + :class:`sqlalchemy.orm.DeclarativeBaseNoMeta` + and passing the ``model_class`` parameter to the extension. """ if engine_options is None: diff --git a/tests/conftest.py b/tests/conftest.py index b767ff99..2634dae2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,6 @@ import sqlalchemy.orm as sa_orm from flask import Flask from flask.ctx import AppContext -from sqlalchemy.orm import Mapped from flask_sqlalchemy import SQLAlchemy @@ -69,10 +68,10 @@ def Todo(app: Flask, db: SQLAlchemy) -> t.Generator[t.Any, None, None]: if issubclass(db.Model, (sa_orm.MappedAsDataclass)): class Todo(db.Model): - id: Mapped[int] = sa_orm.mapped_column( + id: sa_orm.Mapped[int] = sa_orm.mapped_column( sa.Integer, init=False, primary_key=True ) - title: Mapped[str] = sa_orm.mapped_column( + title: sa_orm.Mapped[str] = sa_orm.mapped_column( sa.String, nullable=True, default=None ) @@ -81,8 +80,8 @@ class Todo(db.Model): ): class Todo(db.Model): - id: Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) - title: Mapped[str] = sa_orm.mapped_column(sa.String, nullable=True) + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) + title: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=True) else: From af5096d6073a266b31909a51e81a1b742c2c3991 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Thu, 1 Jun 2023 13:34:19 -0700 Subject: [PATCH 09/27] Getting types to work (mostly) --- CHANGES.rst | 1 + src/flask_sqlalchemy/extension.py | 51 +++++++++++++++++++++++-------- src/flask_sqlalchemy/model.py | 8 +++-- tests/conftest.py | 8 ++--- tests/test_model.py | 9 ++++++ 5 files changed, 58 insertions(+), 19 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 35d8d43a..4f47200f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,7 @@ Version 3.1.0 Unreleased +- Add support for the SQLAlchemy 2.x API via ``model_class`` parameter. :issue:`1140` - Remove previously deprecated code. - Pass extra keyword arguments from ``get_or_404`` to ``session.get``. :issue:`1149` diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 8821b05f..b29db83c 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -29,6 +29,23 @@ _O = t.TypeVar("_O", bound=object) # Based on sqlalchemy.orm._typing.py +# Type accepted for model_class argument +FSA_MC = t.TypeVar( + "FSA_MC", + bound=t.Union[ + Model, + sa_orm.DeclarativeMeta, + sa_orm.DeclarativeBase, + sa_orm.DeclarativeBaseNoMeta, + ], +) + + +# Type returned by make_declarative_base +class FSAModel(Model): + metadata: sa.MetaData + + class SQLAlchemy: """Integrates SQLAlchemy with Flask. This handles setting up one or more engines, associating tables and models with specific engines, and cleaning up connections and @@ -127,7 +144,7 @@ def __init__( metadata: sa.MetaData | None = None, session_options: dict[str, t.Any] | None = None, query_class: type[Query] = Query, - model_class: type[Model] | sa_orm.DeclarativeMeta = Model, + model_class: t.Type[FSA_MC] = Model, # type: ignore[assignment] engine_options: dict[str, t.Any] | None = None, add_models_to_shell: bool = True, ): @@ -448,8 +465,9 @@ def __new__( return Table def _make_declarative_base( - self, model: type[Model] | sa_orm.DeclarativeMeta - ) -> type[t.Any]: + self, + model_class: t.Type[FSA_MC], + ) -> t.Type[FSAModel]: """Create a SQLAlchemy declarative model class. The result is available as :attr:`Model`. @@ -461,7 +479,11 @@ def _make_declarative_base( :meta private: - :param model: A model base class, or an already created declarative model class. + :param model_class: A model base class, or an already created declarative model + class. + + .. versionchanged:: 3.0.4 + Added support for passing SQLAlchemy 2.x base class as model class. .. versionchanged:: 3.0 Renamed with a leading underscore, this method is internal. @@ -469,41 +491,44 @@ def _make_declarative_base( .. versionchanged:: 2.3 ``model`` can be an already created declarative model class. """ + model: t.Type[FSAModel] declarative_bases = [ b - for b in model.__bases__ + for b in model_class.__bases__ if issubclass(b, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)) ] if len(declarative_bases) > 1: # raise error if more than one declarative base is found raise ValueError( "Only one declarative base can be passed to SQLAlchemy." - " Got: {}".format(model.__bases__) + " Got: {}".format(model_class.__bases__) ) elif len(declarative_bases) == 1: body = {"__fsa__": self} model = types.new_class( "FlaskSQLAlchemyBase", - (BindMixin, NameMixin, Model, *model.__bases__), + (BindMixin, NameMixin, Model, *model_class.__bases__), {"metaclass": type(declarative_bases[0])}, lambda ns: ns.update(body), ) - elif not isinstance(model, sa_orm.DeclarativeMeta): + elif not isinstance(model_class, sa.orm.DeclarativeMeta): metadata = self._make_metadata(None) model = sa_orm.declarative_base( - metadata=metadata, cls=model, name="Model", metaclass=DefaultMeta + metadata=metadata, cls=model_class, name="Model", metaclass=DefaultMeta ) + else: + model = model_class if None not in self.metadatas: # Use the model's metadata as the default metadata. - model.metadata.info["bind_key"] = None # type: ignore[union-attr] - self.metadatas[None] = model.metadata # type: ignore[union-attr] + model.metadata.info["bind_key"] = None + self.metadatas[None] = model.metadata else: # Use the passed in default metadata as the model's metadata. - model.metadata = self.metadatas[None] # type: ignore[union-attr] + model.metadata = self.metadatas[None] model.query_class = self.Query - model.query = _QueryProperty() + model.query = _QueryProperty() # type: ignore[assignment] model.__fsa__ = self return model diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 1ec98e1b..546cb831 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -98,13 +98,15 @@ class BindMixin: If the model sets ``metadata`` or ``__table__`` directly, ``__bind_key__`` is ignored. If the ``metadata`` is the same as the parent model, it will not be set directly on the child model. + + .. versionchanged:: 3.0.4 """ __fsa__: SQLAlchemy metadata: sa.MetaData @classmethod - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls: t.Type[BindMixin], **kwargs: t.Dict[str, t.Any]) -> None: if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__): bind_key = getattr(cls, "__bind_key__", None) parent_metadata = getattr(cls, "metadata", None) @@ -190,6 +192,8 @@ class NameMixin: ``CamelCase`` class name to ``snake_case``. A name is set for non-abstract models that do not otherwise define ``__tablename__``. If a model does not define a primary key, it will not generate a name or ``__table__``, for single-table inheritance. + + .. versionchanged:: 3.0.4 """ metadata: sa.MetaData @@ -197,7 +201,7 @@ class NameMixin: __table__: sa.Table @classmethod - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls: t.Type[NameMixin], **kwargs: t.Dict[str, t.Any]) -> None: if should_set_tablename(cls): cls.__tablename__ = camel_to_snake_case(cls.__name__) diff --git a/tests/conftest.py b/tests/conftest.py index 2634dae2..b4a426c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,15 +79,15 @@ class Todo(db.Model): db.Model, (sa_orm.DeclarativeBaseNoMeta, sa_orm.DeclarativeBaseNoMeta) ): - class Todo(db.Model): + class Todo(db.Model): # type: ignore[no-redef] id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) title: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=True) else: - class Todo(db.Model): - id: sa.Column = sa.Column(sa.Integer, primary_key=True) - title: sa.Column = sa.Column(sa.String) + class Todo(db.Model): # type: ignore[no-redef] + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String) with app.app_context(): db.create_all() diff --git a/tests/test_model.py b/tests/test_model.py index 80ced173..5b16da3b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -59,3 +59,12 @@ class User(db.Model): assert repr(user) == f"" db.session.flush() assert repr(user) == f"" + + +@pytest.mark.usefixtures("app_ctx") +def test_too_many_bases(app: Flask) -> None: + class Base(sa.orm.DeclarativeBase, sa.orm.DeclarativeBaseNoMeta): # type: ignore[misc] + pass + + with pytest.raises(ValueError): + SQLAlchemy(app, model_class=Base) From a8da6197506f6a4a9a64fbb517233c651808390b Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 13 Jun 2023 06:43:07 -0700 Subject: [PATCH 10/27] Update sqlalchemy version --- CHANGES.rst | 1 + pyproject.toml | 2 +- requirements/mypy.txt | 2 -- requirements/tests-min.in | 2 +- requirements/tests-min.txt | 8 ++++---- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 4f47200f..8c3e2477 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,7 @@ Version 3.1.0 Unreleased - Add support for the SQLAlchemy 2.x API via ``model_class`` parameter. :issue:`1140` +- Bump minimum version of SQLAlchemy to 2.0.16. - Remove previously deprecated code. - Pass extra keyword arguments from ``get_or_404`` to ``session.get``. :issue:`1149` diff --git a/pyproject.toml b/pyproject.toml index a55f9868..e8460c45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ requires-python = ">=3.7" dependencies = [ "flask>=2.2.5", - "sqlalchemy>=1.4.18", + "sqlalchemy>=2.0.16", ] dynamic = ["version"] diff --git a/requirements/mypy.txt b/requirements/mypy.txt index 983b2686..759cfc8f 100644 --- a/requirements/mypy.txt +++ b/requirements/mypy.txt @@ -5,8 +5,6 @@ # # pip-compile-multi # -greenlet==2.0.2 - # via sqlalchemy iniconfig==2.0.0 # via pytest mypy==1.4.1 diff --git a/requirements/tests-min.in b/requirements/tests-min.in index d7c14b07..59c9c494 100644 --- a/requirements/tests-min.in +++ b/requirements/tests-min.in @@ -1,3 +1,3 @@ flask==2.2.5 werkzeug<2.3 -sqlalchemy==1.4.18 +sqlalchemy==2.0.16 diff --git a/requirements/tests-min.txt b/requirements/tests-min.txt index 12bbffcd..a5dc4228 100644 --- a/requirements/tests-min.txt +++ b/requirements/tests-min.txt @@ -1,4 +1,4 @@ -# SHA1:46c532e5c6f44988d3c02d047503f87ea25a2f72 +# SHA1:8cae0f9c9bfdb0c4ddf46f9fb41a746cc14300a7 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -9,8 +9,6 @@ click==8.1.3 # via flask flask==2.2.5 # via -r requirements/tests-min.in -greenlet==2.0.2 - # via sqlalchemy itsdangerous==2.1.2 # via flask jinja2==3.1.2 @@ -19,8 +17,10 @@ markupsafe==2.1.3 # via # jinja2 # werkzeug -sqlalchemy==1.4.18 +sqlalchemy==2.0.16 # via -r requirements/tests-min.in +typing-extensions==4.6.3 + # via sqlalchemy werkzeug==2.2.3 # via # -r requirements/tests-min.in From d3a30c52bf1d684a7cb8070e12390e702755426b Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 13 Jun 2023 09:17:53 -0700 Subject: [PATCH 11/27] Mark the type classes as internal with an underscore in front --- src/flask_sqlalchemy/extension.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index b29db83c..c9f25106 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -30,8 +30,8 @@ # Type accepted for model_class argument -FSA_MC = t.TypeVar( - "FSA_MC", +_FSA_MC = t.TypeVar( + "_FSA_MC", bound=t.Union[ Model, sa_orm.DeclarativeMeta, @@ -42,7 +42,7 @@ # Type returned by make_declarative_base -class FSAModel(Model): +class _FSAModel(Model): metadata: sa.MetaData @@ -144,7 +144,7 @@ def __init__( metadata: sa.MetaData | None = None, session_options: dict[str, t.Any] | None = None, query_class: type[Query] = Query, - model_class: t.Type[FSA_MC] = Model, # type: ignore[assignment] + model_class: t.Type[_FSA_MC] = Model, # type: ignore[assignment] engine_options: dict[str, t.Any] | None = None, add_models_to_shell: bool = True, ): @@ -466,8 +466,8 @@ def __new__( def _make_declarative_base( self, - model_class: t.Type[FSA_MC], - ) -> t.Type[FSAModel]: + model_class: t.Type[_FSA_MC], + ) -> t.Type[_FSAModel]: """Create a SQLAlchemy declarative model class. The result is available as :attr:`Model`. @@ -491,7 +491,7 @@ def _make_declarative_base( .. versionchanged:: 2.3 ``model`` can be an already created declarative model class. """ - model: t.Type[FSAModel] + model: t.Type[_FSAModel] declarative_bases = [ b for b in model_class.__bases__ From 73e02065063bb6a2f9a6a6c6e53291215f09b944 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Mon, 3 Jul 2023 17:28:05 -0700 Subject: [PATCH 12/27] Fix tests --- src/flask_sqlalchemy/extension.py | 2 +- tests/conftest.py | 4 +-- tests/test_model.py | 2 +- tests/test_view_query.py | 41 +++++++++++++++++++++++-------- 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index c9f25106..4a13056a 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -511,7 +511,7 @@ def _make_declarative_base( {"metaclass": type(declarative_bases[0])}, lambda ns: ns.update(body), ) - elif not isinstance(model_class, sa.orm.DeclarativeMeta): + elif not isinstance(model_class, sa_orm.DeclarativeMeta): metadata = self._make_metadata(None) model = sa_orm.declarative_base( metadata=metadata, cls=model_class, name="Model", metaclass=DefaultMeta diff --git a/tests/conftest.py b/tests/conftest.py index b4a426c1..0733eadd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -75,9 +75,7 @@ class Todo(db.Model): sa.String, nullable=True, default=None ) - elif issubclass( - db.Model, (sa_orm.DeclarativeBaseNoMeta, sa_orm.DeclarativeBaseNoMeta) - ): + elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): class Todo(db.Model): # type: ignore[no-redef] id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) diff --git a/tests/test_model.py b/tests/test_model.py index 5b16da3b..1bd5f850 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -63,7 +63,7 @@ class User(db.Model): @pytest.mark.usefixtures("app_ctx") def test_too_many_bases(app: Flask) -> None: - class Base(sa.orm.DeclarativeBase, sa.orm.DeclarativeBaseNoMeta): # type: ignore[misc] + class Base(sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta): # type: ignore[misc] pass with pytest.raises(ValueError): diff --git a/tests/test_view_query.py b/tests/test_view_query.py index ddbf1bbc..c1d056c1 100644 --- a/tests/test_view_query.py +++ b/tests/test_view_query.py @@ -4,6 +4,7 @@ import pytest import sqlalchemy as sa +import sqlalchemy.orm as sa_orm from flask import Flask from werkzeug.exceptions import NotFound @@ -64,20 +65,40 @@ def test_paginate(db: SQLAlchemy, Todo: t.Any) -> None: # This test creates its own inline model so that it can use that as the type @pytest.mark.usefixtures("app_ctx") def test_view_get_or_404_typed(db: SQLAlchemy, app: Flask) -> None: - class Quiz(db.Model): - id = sa.Column(sa.Integer, primary_key=True) - topic = sa.Column(sa.String) + # Copied and pasted from conftest.py + if issubclass(db.Model, (sa_orm.MappedAsDataclass)): + + class Todo(db.Model): + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + sa.Integer, init=False, primary_key=True + ) + title: sa_orm.Mapped[str] = sa_orm.mapped_column( + sa.String, nullable=True, default=None + ) + + elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): + + class Todo(db.Model): # type: ignore[no-redef] + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) + title: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=True) + + else: + + class Todo(db.Model): # type: ignore[no-redef] + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String) db.create_all() - item: Quiz = Quiz(topic="Python") - db.session.add(item) + todo = Todo() + todo.title = "Python" + db.session.add(todo) db.session.commit() - result = db.get_or_404(Quiz, 1) - assert result.topic == "Python" - assert result is item + result = db.get_or_404(Todo, 1) + assert result.title == "Python" + assert result is todo if hasattr(t, "assert_type"): - t.assert_type(result, Quiz) + t.assert_type(result, Todo) with pytest.raises(NotFound): - assert db.get_or_404(Quiz, 2) + assert db.get_or_404(Todo, 2) db.drop_all() From 2c3650af7a61ee30846e693d423153ae62058cb7 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 5 Jul 2023 13:49:26 -0700 Subject: [PATCH 13/27] Addressing docs note --- docs/quickstart.rst | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 6cee5db2..1844a4dc 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -44,6 +44,15 @@ Initialize the Extension ------------------------ First create the ``db`` object using the ``SQLAlchemy`` constructor. +The initialization step depends on which version of ``SQLAlchemy`` you're using. +This extension supports both SQLAlchemy 1 and 2, but defaults to SQLAlchemy 1. + +.. _sqlalchemy1-initialization: + +Using the SQLAlchemy 1 API +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To use the SQLAlchemy 1.x API, you do not need to pass any arguments to the ``SQLAlchemy`` constructor. .. code-block:: python @@ -53,10 +62,12 @@ First create the ``db`` object using the ``SQLAlchemy`` constructor. db = SQLAlchemy() +.. _sqlalchemy2-initialization: -By default, this extension assumes that you are using the SQLAlchemy 1.x API for defining models. +Using the SQLAlchemy 2 API +^^^^^^^^^^^^^^^^^^^^^^^^^^ -To use the new SQLAlchemy 2.x API, pass a subclass of either ``DeclarativeBase`` or ``DeclarativeBaseNoMeta`` +To use the new SQLAlchemy 2.x API, pass a subclass of either `DeclarativeBase`_ or `DeclarativeBaseNoMeta`_ to the constructor. .. code-block:: python @@ -70,6 +81,25 @@ to the constructor. db = SQLAlchemy(model_class=Base) +If desired, you can enable `SQLAlchemy's native support for data classes`_ +by adding `MappedAsDataclass` as an additional parent class. + +.. code-block:: python + + from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass + + class Base(DeclarativeBase, MappedAsDataclass): + pass + + db = SQLAlchemy(model_class=Base) + +.. _DeclarativeBase: https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.DeclarativeBase +.. _DeclarativeBaseNoMeta: https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.DeclarativeBaseNoMeta +.. _SQLAlchemy's native support for data classes: https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#native-support-for-dataclasses-mapped-as-orm-models + +About the ``SQLAlchemy`` object +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + Once constructed, the ``db`` object gives you access to the :attr:`db.Model <.SQLAlchemy.Model>` class to define models, and the :attr:`db.session <.SQLAlchemy.session>` to execute queries. @@ -122,7 +152,7 @@ The table name ``"user"`` will automatically be assigned to the model's table. It's also possible to use the SQLAlchemy 2.x style of defining models, as long as you initialized the extension with an appropriate 2.x model base class -as described above. +as described in :ref:`sqlalchemy2-initialization`. .. code-block:: python From 7610ea7b12a2c1d307d5a58c5c5493e4c568da9e Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 5 Jul 2023 15:51:10 -0700 Subject: [PATCH 14/27] Add constraint of les than 2.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e8460c45..ca0b9420 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ requires-python = ">=3.7" dependencies = [ "flask>=2.2.5", - "sqlalchemy>=2.0.16", + "sqlalchemy>=2.0.16,<2.1", ] dynamic = ["version"] From 8fdc47acf862a000747c4c593cce7f21ac00e42e Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 5 Jul 2023 17:14:58 -0700 Subject: [PATCH 15/27] Remove version cap per Davids comment --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ca0b9420..e8460c45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ requires-python = ">=3.7" dependencies = [ "flask>=2.2.5", - "sqlalchemy>=2.0.16,<2.1", + "sqlalchemy>=2.0.16", ] dynamic = ["version"] From 1bbd722065acf10fd8f7747bfb793afb98ff636d Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 11 Jul 2023 14:16:51 -0700 Subject: [PATCH 16/27] Add disable_autonaming --- docs/customizing.rst | 68 +++++-------------------------- src/flask_sqlalchemy/extension.py | 24 +++++++++-- src/flask_sqlalchemy/model.py | 6 +++ tests/test_model.py | 24 +++++++++++ 4 files changed, 61 insertions(+), 61 deletions(-) diff --git a/docs/customizing.rst b/docs/customizing.rst index c9f76299..7540a729 100644 --- a/docs/customizing.rst +++ b/docs/customizing.rst @@ -160,71 +160,25 @@ To customize only ``session.query``, pass the ``query_cls`` key to the db = SQLAlchemy(session_options={"query_cls": GetOrQuery}) -Model Metaclass ---------------- - -.. warning:: - Metaclasses are an advanced topic, and you probably don't need to customize them to - achieve what you want. It is mainly documented here to show how to disable table - name generation. - -The model metaclass is responsible for setting up the SQLAlchemy internals when defining -model subclasses. Flask-SQLAlchemy adds some extra behaviors through mixins; its default -metaclass, :class:`~.DefaultMeta`, inherits them all. - -- :class:`.BindMetaMixin`: ``__bind_key__`` sets the bind to use for the model. -- :class:`.NameMetaMixin`: If the model does not specify a ``__tablename__`` but does - specify a primary key, a name is automatically generated. - -You can add your own behaviors by defining your own metaclass and creating the -declarative base yourself. Be sure to still inherit from the mixins you want (or just -inherit from the default metaclass). - -Passing a declarative base class instead of a simple model base class to ``model_class`` -will cause Flask-SQLAlchemy to use this base instead of constructing one with the -default metaclass. - -.. code-block:: python - - from sqlalchemy.orm import declarative_base - from flask_sqlalchemy import SQLAlchemy - from flask_sqlalchemy.model import DefaultMeta, Model - - class CustomMeta(DefaultMeta): - def __init__(cls, name, bases, d): - # custom class setup could go here - - # be sure to call super - super(CustomMeta, cls).__init__(name, bases, d) - - # custom class-only methods could go here - - CustomModel = declarative_base(cls=Model, metaclass=CustomMeta, name="Model") - db = SQLAlchemy(model_class=CustomModel) - -You can also pass whatever other arguments you want to -:func:`~sqlalchemy.orm.declarative_base` to customize the base class. - - Disabling Table Name Generation -``````````````````````````````` +------------------------------- Some projects prefer to set each model's ``__tablename__`` manually rather than relying on Flask-SQLAlchemy's detection and generation. The simple way to achieve that is to set each ``__tablename__`` and not modify the base class. However, the table name -generation can be disabled by defining a custom metaclass with only the -``BindMetaMixin`` and not the ``NameMetaMixin``. +generation can be disabled by setting `disable_autonaming=True` in the `SQLAlchemy` constructor. + +Example code using the SQLAlchemy 1.x (legacy) API: .. code-block:: python - from sqlalchemy.orm import DeclarativeMeta, declarative_base - from flask_sqlalchemy.model import BindMetaMixin, Model + db = SQLAlchemy(app, disable_autonaming=True) - class NoNameMeta(BindMetaMixin, DeclarativeMeta): - pass +Example code using the SQLAlchemy 2.x declarative base: - CustomModel = declarative_base(cls=Model, metaclass=NoNameMeta, name="Model") - db = SQLAlchemy(model_class=CustomModel) +.. code-block:: python + + class Base(sa_orm.DeclarativeBase): + pass -This creates a base that still supports the ``__bind_key__`` feature but does not -generate table names. + db = SQLAlchemy(app, model_class=Base, disable_autonaming=True) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 4a13056a..9cddfa65 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -17,6 +17,7 @@ from .model import _QueryProperty from .model import BindMixin from .model import DefaultMeta +from .model import DefaultMetaNoName from .model import Model from .model import NameMixin from .pagination import Pagination @@ -86,6 +87,10 @@ class SQLAlchemy: :param add_models_to_shell: Add the ``db`` instance and all model classes to ``flask shell``. + .. versionchanged:: 3.1.0 + Added the ``disable_autonaming`` parameter and changed ``model_class`` parameter + to accept a SQLAlchemy 2.0-style declarative base subclass. + .. versionchanged:: 3.0 An active Flask application context is always required to access ``session`` and ``engine``. @@ -147,6 +152,7 @@ def __init__( model_class: t.Type[_FSA_MC] = Model, # type: ignore[assignment] engine_options: dict[str, t.Any] | None = None, add_models_to_shell: bool = True, + disable_autonaming: bool = False, ): if session_options is None: session_options = {} @@ -207,7 +213,9 @@ def __init__( This is a subclass of SQLAlchemy's ``Table`` rather than a function. """ - self.Model = self._make_declarative_base(model_class) + self.Model = self._make_declarative_base( + model_class, disable_autonaming=disable_autonaming + ) """A SQLAlchemy declarative model class. Subclass this to define database models. @@ -467,6 +475,7 @@ def __new__( def _make_declarative_base( self, model_class: t.Type[_FSA_MC], + disable_autonaming: bool = False, ) -> t.Type[_FSAModel]: """Create a SQLAlchemy declarative model class. The result is available as :attr:`Model`. @@ -482,8 +491,11 @@ def _make_declarative_base( :param model_class: A model base class, or an already created declarative model class. - .. versionchanged:: 3.0.4 + :param disable_autonaming: Turns off automatic tablename generation in models. + + .. versionchanged:: 3.1.0 Added support for passing SQLAlchemy 2.x base class as model class. + Added optional ``disable_autonaming`` parameter. .. versionchanged:: 3.0 Renamed with a leading underscore, this method is internal. @@ -505,16 +517,20 @@ def _make_declarative_base( ) elif len(declarative_bases) == 1: body = {"__fsa__": self} + mixin_classes = [BindMixin, NameMixin, Model] + if disable_autonaming: + mixin_classes.remove(NameMixin) model = types.new_class( "FlaskSQLAlchemyBase", - (BindMixin, NameMixin, Model, *model_class.__bases__), + (*mixin_classes, *model_class.__bases__), {"metaclass": type(declarative_bases[0])}, lambda ns: ns.update(body), ) elif not isinstance(model_class, sa_orm.DeclarativeMeta): metadata = self._make_metadata(None) + metaclass = DefaultMetaNoName if disable_autonaming else DefaultMeta model = sa_orm.declarative_base( - metadata=metadata, cls=model_class, name="Model", metaclass=DefaultMeta + metadata=metadata, cls=model_class, name="Model", metaclass=metaclass ) else: model = model_class diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 546cb831..d77bd126 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -314,3 +314,9 @@ class DefaultMeta(BindMetaMixin, NameMetaMixin, sa_orm.DeclarativeMeta): """SQLAlchemy declarative metaclass that provides ``__bind_key__`` and ``__tablename__`` support. """ + + +class DefaultMetaNoName(BindMetaMixin, sa_orm.DeclarativeMeta): + """SQLAlchemy declarative metaclass that provides ``__bind_key__`` and + ``__tablename__`` support. + """ diff --git a/tests/test_model.py b/tests/test_model.py index 1bd5f850..1b316f0c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,6 +4,7 @@ import pytest import sqlalchemy as sa +import sqlalchemy.exc as sa_exc import sqlalchemy.orm as sa_orm from flask import Flask @@ -68,3 +69,26 @@ class Base(sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta): # type: ignor with pytest.raises(ValueError): SQLAlchemy(app, model_class=Base) + + +@pytest.mark.usefixtures("app_ctx") +def test_disable_autonaming_true_sql1(app: Flask) -> None: + db = SQLAlchemy(app, disable_autonaming=True) + + with pytest.raises(sa_exc.InvalidRequestError): + + class User(db.Model): + id = sa.Column(sa.Integer, primary_key=True) + + +@pytest.mark.usefixtures("app_ctx") +def test_disable_autonaming_true_sql2(app: Flask) -> None: + class Base(sa_orm.DeclarativeBase): + pass + + db = SQLAlchemy(app, model_class=Base, disable_autonaming=True) + + with pytest.raises(sa_exc.InvalidRequestError): + + class User(db.Model): + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) From b052dd6f3cd9d2427b498f80f2c298c8813e0cbe Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Sun, 20 Aug 2023 14:30:04 -0700 Subject: [PATCH 17/27] Update versionchanged for new mixins --- src/flask_sqlalchemy/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index d77bd126..e3dac5e7 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -99,7 +99,7 @@ class BindMixin: ignored. If the ``metadata`` is the same as the parent model, it will not be set directly on the child model. - .. versionchanged:: 3.0.4 + .. versionchanged:: 3.1.0 """ __fsa__: SQLAlchemy @@ -193,7 +193,7 @@ class NameMixin: that do not otherwise define ``__tablename__``. If a model does not define a primary key, it will not generate a name or ``__table__``, for single-table inheritance. - .. versionchanged:: 3.0.4 + .. versionchanged:: 3.1.0 """ metadata: sa.MetaData From d8c31ca9eb40391576168d88dfd3c65c01d7841f Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 22 Aug 2023 11:59:11 -0700 Subject: [PATCH 18/27] Document refactor to emphasize 2.x, fix for bind key and metadata for 2.x --- docs/config.rst | 25 -------- docs/legacy-quickstart.rst | 95 +++++++++++++++++++++++++++++++ docs/models.rst | 88 +++++++++++++++++++++------- docs/quickstart.rst | 68 ++++------------------ src/flask_sqlalchemy/extension.py | 34 ++++++++--- src/flask_sqlalchemy/model.py | 15 +++-- tests/conftest.py | 24 ++++++-- tests/test_metadata.py | 63 ++++++++++++++++---- 8 files changed, 283 insertions(+), 129 deletions(-) create mode 100644 docs/legacy-quickstart.rst diff --git a/docs/config.rst b/docs/config.rst index a03276b6..21a27ebc 100644 --- a/docs/config.rst +++ b/docs/config.rst @@ -151,31 +151,6 @@ only need to use :data:`SQLALCHEMY_DATABASE_URI` and :data:`SQLALCHEMY_ENGINE_OP in that engine's options. -Using custom MetaData and naming conventions --------------------------------------------- - -You can optionally construct the :class:`.SQLAlchemy` object with a custom -:class:`~sqlalchemy.schema.MetaData` object. This allows you to specify a custom -constraint `naming convention`_. This makes constraint names consistent and predictable, -useful when using migrations, as described by `Alembic`_. - -.. code-block:: python - - from sqlalchemy import MetaData - from flask_sqlalchemy import SQLAlchemy - - db = SQLAlchemy(metadata=MetaData(naming_convention={ - "ix": 'ix_%(column_0_label)s', - "uq": "uq_%(table_name)s_%(column_0_name)s", - "ck": "ck_%(table_name)s_%(constraint_name)s", - "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(table_name)s" - })) - -.. _naming convention: https://docs.sqlalchemy.org/core/constraints.html#constraint-naming-conventions -.. _Alembic: https://alembic.sqlalchemy.org/en/latest/naming.html - - Timeouts -------- diff --git a/docs/legacy-quickstart.rst b/docs/legacy-quickstart.rst new file mode 100644 index 00000000..0d09b9e2 --- /dev/null +++ b/docs/legacy-quickstart.rst @@ -0,0 +1,95 @@ + +:orphan: + +Legacy Quickstart +====================== + +.. warning:: + This guide shows you how to initialize the extension and define models + when using the SQLAlchemy 1.x style of ORM model classes. We encourage you to + upgrade to `SQLAlchemy 2.x`_ to take advantage of the new typed model classes. + +.. _SQLAlchemy 2.x: https://docs.sqlalchemy.org/en/20/orm/quickstart.html + +Initialize the Extension +------------------------ + +First create the ``db`` object using the ``SQLAlchemy`` constructor. + +When using the SQLAlchemy 1.x API, you do not need to pass any arguments to the ``SQLAlchemy`` constructor. +A declarative base class will be created behind the scenes for you. + +.. code-block:: python + + from flask import Flask + from flask_sqlalchemy import SQLAlchemy + from sqlalchemy.orm import DeclarativeBase + + db = SQLAlchemy() + + +Using custom MetaData and naming conventions +-------------------------------------------- + +You can optionally construct the :class:`.SQLAlchemy` object with a custom +:class:`~sqlalchemy.schema.MetaData` object. This allows you to specify a custom +constraint `naming convention`_. This makes constraint names consistent and predictable, +useful when using migrations, as described by `Alembic`_. + +.. code-block:: python + + from sqlalchemy import MetaData + from flask_sqlalchemy import SQLAlchemy + + db = SQLAlchemy(metadata=MetaData(naming_convention={ + "ix": 'ix_%(column_0_label)s', + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s" + })) + +.. _naming convention: https://docs.sqlalchemy.org/core/constraints.html#constraint-naming-conventions +.. _Alembic: https://alembic.sqlalchemy.org/en/latest/naming.html + + + +Define Models +------------- + +Subclass ``db.Model`` to define a model class. This is a SQLAlchemy declarative base +class, it will take ``Column`` attributes and create a table. + +.. code-block:: python + + class User(db.Model): + id = db.Column(db.Integer, primary_key=True) + username = db.Column(db.String, unique=True, nullable=False) + email = db.Column(db.String) + +For convenience, the extension object provides access to names in the ``sqlalchemy`` and +``sqlalchemy.orm`` modules. So you can use ``db.Column`` instead of importing and using +``sqlalchemy.Column``, although the two are equivalent. + +Unlike plain SQLAlchemy, Flask-SQLAlchemy's model will automatically generate a table name +if ``__tablename__`` is not set and a primary key column is defined. +The table name ``"user"`` will automatically be assigned to the model's table. + + +Create the Tables +----------------- + +Defining a model does not create it in the database. Use :meth:`~.SQLAlchemy.create_all` +to create the models and tables after defining them. If you define models in submodules, +you must import them so that SQLAlchemy knows about them before calling ``create_all``. + +.. code-block:: python + + with app.app_context(): + db.create_all() + +Querying the Data +----------------- + +You can query the data the same way regardless of SQLAlchemy version. +See :doc:`queries` for more information about queries. diff --git a/docs/models.rst b/docs/models.rst index 56a7b16a..f2c6a159 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -4,35 +4,82 @@ Models and Tables Use the ``db.Model`` class to define models, or the ``db.Table`` class to create tables. Both handle Flask-SQLAlchemy's bind keys to associate with a specific engine. +Initializing the Base Class +--------------------------- -Defining Models ---------------- +``SQLAlchemy`` 2.x offers several possible base classes for your models: +`DeclarativeBase`_ or `DeclarativeBaseNoMeta`_. -See SQLAlchemy's `declarative documentation`_ for full information about defining model -classes declaratively. +Create a subclass of one of those classes: -.. _declarative documentation: https://docs.sqlalchemy.org/orm/declarative_tables.html +.. code-block:: python -Subclass ``db.Model`` to create a model class. This is a SQLAlchemy declarative base -class, it will take ``Column`` attributes and create a table. Unlike plain SQLAlchemy, -Flask-SQLAlchemy's model will automatically generate a table name if ``__tablename__`` -is not set and a primary key column is defined. + from flask import Flask + from flask_sqlalchemy import SQLAlchemy + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass + +.. _DeclarativeBase: https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.DeclarativeBase +.. _DeclarativeBaseNoMeta: https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.DeclarativeBaseNoMeta + +If desired, you can enable `SQLAlchemy's native support for data classes`_ +by adding `MappedAsDataclass` as an additional parent class. .. code-block:: python - import sqlalchemy as sa + from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass - class User(db.Model): - id = sa.Column(sa.Integer, primary_key=True) - type = sa.Column(sa.String) + class Base(DeclarativeBase, MappedAsDataclass): + pass + +.. _SQLAlchemy's native support for data classes: https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#native-support-for-dataclasses-mapped-as-orm-models + + +You can optionally construct the :class:`.SQLAlchemy` object with a custom +:class:`~sqlalchemy.schema.MetaData` object. This allows you to specify a custom +constraint `naming convention`_. This makes constraint names consistent and predictable, +useful when using migrations, as described by `Alembic`_. + +.. code-block:: python + + from sqlalchemy import MetaData -For convenience, the extension object provides access to names in the ``sqlalchemy`` and -``sqlalchemy.orm`` modules. So you can use ``db.Column`` instead of importing and using -``sqlalchemy.Column``, although the two are equivalent. + class Base(DeclarativeBase): + metadata = MetaData(naming_convention={ + "ix": 'ix_%(column_0_label)s', + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s" + }) -It's also possible to use the SQLAlchemy 2.x style of defining models, -as long as you initialized the extension with an appropriate 2.x model base class -(as described in the quickstart). +.. _naming convention: https://docs.sqlalchemy.org/core/constraints.html#constraint-naming-conventions +.. _Alembic: https://alembic.sqlalchemy.org/en/latest/naming.html + + +Initialize the Extension +------------------------ + +Once you've defined a base class, create the ``db`` object using the ``SQLAlchemy`` constructor. + +.. code-block:: python + + db = SQLAlchemy(model_class=Base) + + +Defining Models +--------------- + +See SQLAlchemy's `declarative documentation`_ for full information about defining model +classes declaratively. + +.. _declarative documentation: https://docs.sqlalchemy.org/en/20/orm/declarative_tables.html + +Subclass ``db.Model`` to create a model class. Unlike plain SQLAlchemy, +Flask-SQLAlchemy's model will automatically generate a table name if ``__tablename__`` +is not set and a primary key column is defined. .. code-block:: python @@ -43,6 +90,7 @@ as long as you initialized the extension with an appropriate 2.x model base clas username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False) email: Mapped[str] = mapped_column(db.String) + Defining a model does not create it in the database. Use :meth:`~.SQLAlchemy.create_all` to create the models and tables after defining them. If you define models in submodules, you must import them so that SQLAlchemy knows about them before calling ``create_all``. @@ -59,7 +107,7 @@ Defining Tables See SQLAlchemy's `table documentation`_ for full information about defining table objects. -.. _table documentation: https://docs.sqlalchemy.org/core/metadata.html +.. _table documentation: https://docs.sqlalchemy.org/en/20/core/metadata.html Create instances of ``db.Table`` to define tables. The class takes a table name, then any columns and other table parts such as columns and constraints. Unlike plain diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 1844a4dc..b32b8660 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -21,11 +21,14 @@ Flask-SQLAlchemy is a wrapper around SQLAlchemy. You should follow the `SQLAlchemy Tutorial`_ to learn about how to use it, and consult its documentation for detailed information about its features. These docs show how to set up Flask-SQLAlchemy itself, not how to use SQLAlchemy. Flask-SQLAlchemy sets up the -engine, declarative model class, and scoped session automatically, so you can skip those +engine and scoped session automatically, so you can skip those parts of the SQLAlchemy tutorial. -.. _SQLAlchemy Tutorial: https://docs.sqlalchemy.org/tutorial/index.html +.. _SQLAlchemy Tutorial: https://docs.sqlalchemy.org/en/20/tutorial/index.html +This guide assumes you are using SQLAlchemy 2.x, which has a new API for defining models +and better support for Python type hints and data classes. If you are using SQLAlchemy 1.x, +see :doc:`legacy-quickstart`. Installation ------------ @@ -44,30 +47,8 @@ Initialize the Extension ------------------------ First create the ``db`` object using the ``SQLAlchemy`` constructor. -The initialization step depends on which version of ``SQLAlchemy`` you're using. -This extension supports both SQLAlchemy 1 and 2, but defaults to SQLAlchemy 1. -.. _sqlalchemy1-initialization: - -Using the SQLAlchemy 1 API -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To use the SQLAlchemy 1.x API, you do not need to pass any arguments to the ``SQLAlchemy`` constructor. - -.. code-block:: python - - from flask import Flask - from flask_sqlalchemy import SQLAlchemy - from sqlalchemy.orm import DeclarativeBase - - db = SQLAlchemy() - -.. _sqlalchemy2-initialization: - -Using the SQLAlchemy 2 API -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To use the new SQLAlchemy 2.x API, pass a subclass of either `DeclarativeBase`_ or `DeclarativeBaseNoMeta`_ +Pass a subclass of either `DeclarativeBase`_ or `DeclarativeBaseNoMeta`_ to the constructor. .. code-block:: python @@ -81,21 +62,11 @@ to the constructor. db = SQLAlchemy(model_class=Base) -If desired, you can enable `SQLAlchemy's native support for data classes`_ -by adding `MappedAsDataclass` as an additional parent class. -.. code-block:: python - - from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass - - class Base(DeclarativeBase, MappedAsDataclass): - pass - - db = SQLAlchemy(model_class=Base) +Learn more about customizing the base model class in :doc:`models`. .. _DeclarativeBase: https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.DeclarativeBase .. _DeclarativeBaseNoMeta: https://docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.DeclarativeBaseNoMeta -.. _SQLAlchemy's native support for data classes: https://docs.sqlalchemy.org/en/20/changelog/whatsnew_20.html#native-support-for-dataclasses-mapped-as-orm-models About the ``SQLAlchemy`` object ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -134,34 +105,19 @@ keys are used. Define Models ------------- -Subclass ``db.Model`` to define a model class. The ``db`` object makes the names in -``sqlalchemy`` and ``sqlalchemy.orm`` available for convenience, such as ``db.Column``. +Subclass ``db.Model`` to define a model class. The model will generate a table name by converting the ``CamelCase`` class name to ``snake_case``. -This example uses the SQLAlchemy 1.x style of defining models: - -.. code-block:: python - - class User(db.Model): - id = db.Column(db.Integer, primary_key=True) - username = db.Column(db.String, unique=True, nullable=False) - email = db.Column(db.String) - -The table name ``"user"`` will automatically be assigned to the model's table. - -It's also possible to use the SQLAlchemy 2.x style of defining models, -as long as you initialized the extension with an appropriate 2.x model base class -as described in :ref:`sqlalchemy2-initialization`. - .. code-block:: python + from sqlalchemy import Integer, String from sqlalchemy.orm import Mapped, mapped_column class User(db.Model): - id: Mapped[int] = mapped_column(db.Integer, primary_key=True) - username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False) - email: Mapped[str] = mapped_column(db.String) + id: Mapped[int] = mapped_column(Integer, primary_key=True) + username: Mapped[str] = mapped_column(String, unique=True, nullable=False) + email: Mapped[str] = mapped_column(String) See :doc:`models` for more information about defining and creating models and tables. diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 9cddfa65..9e6e4e50 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -3,6 +3,7 @@ import os import types import typing as t +import warnings from weakref import WeakKeyDictionary import sqlalchemy as sa @@ -47,6 +48,14 @@ class _FSAModel(Model): metadata: sa.MetaData +def _get_2x_declarative_bases(model_class: t.Type[_FSA_MC]) -> list[t.Type]: + return [ + b + for b in model_class.__bases__ + if issubclass(b, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)) + ] + + class SQLAlchemy: """Integrates SQLAlchemy with Flask. This handles setting up one or more engines, associating tables and models with specific engines, and cleaning up connections and @@ -72,6 +81,10 @@ class SQLAlchemy: :param app: Call :meth:`init_app` on this Flask application now. :param metadata: Use this as the default :class:`sqlalchemy.schema.MetaData`. Useful for setting a naming convention. + .. deprecated:: 3.1.0 + This parameter can still be used in conjunction with SQLAlchemy 1.x classes, + but is ignored when using SQLAlchemy 2.x style of declarative classes. + Instead, specify metadata on your Base class. :param session_options: Arguments used by :attr:`session` to create each session instance. A ``scopefunc`` key will be passed to the scoped session, not the session instance. See :class:`sqlalchemy.orm.sessionmaker` for a list of @@ -194,8 +207,17 @@ def __init__( """ if metadata is not None: - metadata.info["bind_key"] = None - self.metadatas[None] = metadata + if len(_get_2x_declarative_bases(model_class)) > 0: + warnings.warn( + "When using SQLAlchemy 2.x style of declarative classes," + " the `metadata` should be an attribute of the base class." + "The metadata passed into SQLAlchemy() is ignored.", + DeprecationWarning, + stacklevel=2, + ) + else: + metadata.info["bind_key"] = None + self.metadatas[None] = metadata self.Table = self._make_table_class() """A :class:`sqlalchemy.schema.Table` class that chooses a metadata @@ -504,11 +526,7 @@ def _make_declarative_base( ``model`` can be an already created declarative model class. """ model: t.Type[_FSAModel] - declarative_bases = [ - b - for b in model_class.__bases__ - if issubclass(b, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)) - ] + declarative_bases = _get_2x_declarative_bases(model_class) if len(declarative_bases) > 1: # raise error if more than one declarative base is found raise ValueError( @@ -516,7 +534,7 @@ def _make_declarative_base( " Got: {}".format(model_class.__bases__) ) elif len(declarative_bases) == 1: - body = {"__fsa__": self} + body = {"__fsa__": self, "metadata": model_class.metadata} mixin_classes = [BindMixin, NameMixin, Model] if disable_autonaming: mixin_classes.remove(NameMixin) diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index e3dac5e7..d1989a66 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -95,8 +95,12 @@ def __init__( class BindMixin: """DeclarativeBase mixin to set a model's ``metadata`` based on ``__bind_key__``. - If the model sets ``metadata`` or ``__table__`` directly, ``__bind_key__`` is - ignored. If the ``metadata`` is the same as the parent model, it will not be set + If no ``__bind_key__`` is specified, the model will use the default metadata + provided by ``DeclarativeBase`` or ``DeclarativeBaseNoMeta``. + If the model doesn't set ``metadata`` or ``__table__`` directly + and does set ``__bind_key__``, the model will use the metadata + for the specified bind key. + If the ``metadata`` is the same as the parent model, it will not be set directly on the child model. .. versionchanged:: 3.1.0 @@ -107,7 +111,9 @@ class BindMixin: @classmethod def __init_subclass__(cls: t.Type[BindMixin], **kwargs: t.Dict[str, t.Any]) -> None: - if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__): + if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__) and hasattr( + cls, "__bind_key__" + ): bind_key = getattr(cls, "__bind_key__", None) parent_metadata = getattr(cls, "metadata", None) metadata = cls.__fsa__._make_metadata(bind_key) @@ -270,11 +276,10 @@ def should_set_tablename(cls: type) -> bool: Later, ``__table_cls__`` will determine if the model looks like single or joined-table inheritance. If no primary key is found, the name will be unset. """ - uses_2pt0 = issubclass(cls, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)) if ( cls.__dict__.get("__abstract__", False) or ( - not uses_2pt0 + not issubclass(cls, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)) and not any(isinstance(b, sa_orm.DeclarativeMeta) for b in cls.__mro__[1:]) ) or any( diff --git a/tests/conftest.py b/tests/conftest.py index 0733eadd..42eda133 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,24 +27,30 @@ def app_ctx(app: Flask) -> t.Generator[AppContext, None, None]: yield ctx +# Each test that uses the db or model_class fixture will be tested against +# each of these possible base class setups. +# The first one is None, which will trigger creation of a SQLAlchemy 1.x base. +# The remaining four are used to create SQLAlchemy 2.x bases. +# We defer creation of those classes until the fixture, +# so that each test gets a fresh class with its own metadata. test_classes = [ None, - types.new_class( + ( "BaseDeclarativeBase", (sa_orm.DeclarativeBase,), {"metaclass": type(sa_orm.DeclarativeBase)}, ), - types.new_class( + ( "BaseDataclassDeclarativeBase", (sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase), {"metaclass": type(sa_orm.DeclarativeBase)}, ), - types.new_class( + ( "BaseDeclarativeBaseNoMeta", (sa_orm.DeclarativeBaseNoMeta,), {"metaclass": type(sa_orm.DeclarativeBaseNoMeta)}, ), - types.new_class( + ( "BaseDataclassDeclarativeBaseNoMeta", ( sa_orm.MappedAsDataclass, @@ -58,11 +64,19 @@ def app_ctx(app: Flask) -> t.Generator[AppContext, None, None]: @pytest.fixture(params=test_classes) def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy: if request.param is not None: - return SQLAlchemy(app, model_class=request.param) + return SQLAlchemy(app, model_class=types.new_class(*request.param)) else: return SQLAlchemy(app) +@pytest.fixture(params=test_classes) +def model_class(request: pytest.FixtureRequest) -> t.Any: + if request.param is not None: + return types.new_class(*request.param) + else: + return None + + @pytest.fixture def Todo(app: Flask, db: SQLAlchemy) -> t.Generator[t.Any, None, None]: if issubclass(db.Model, (sa_orm.MappedAsDataclass)): diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 88cac794..26ad5358 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + import pytest import sqlalchemy as sa import sqlalchemy.exc as sa_exc @@ -17,7 +19,7 @@ def test_default_metadata(db: SQLAlchemy) -> None: assert db.Model.metadata is db.metadata -def test_custom_metadata() -> None: +def test_custom_metadata_1x() -> None: metadata = sa.MetaData() db = SQLAlchemy(metadata=metadata) assert db.metadata is metadata @@ -25,15 +27,47 @@ def test_custom_metadata() -> None: assert db.Model.metadata is db.metadata -def test_metadata_from_custom_model() -> None: - base = sa_orm.declarative_base(cls=Model, metaclass=DefaultMeta) +def test_custom_metadata_2x_wrongway() -> None: + custom_metadata = sa.MetaData() + + class Base(sa_orm.DeclarativeBase): + pass + + with pytest.deprecated_call(): + db = SQLAlchemy(model_class=Base, metadata=custom_metadata) + + assert db.metadata is Base.metadata + assert db.metadata.info["bind_key"] is None + assert db.Model.metadata is db.metadata + + +def test_custom_metadata_2x() -> None: + custom_metadata = sa.MetaData() + + class Base(sa_orm.DeclarativeBase): + metadata = custom_metadata + + db = SQLAlchemy(model_class=Base) + + assert db.metadata is custom_metadata + assert db.metadata.info["bind_key"] is None + assert db.Model.metadata is db.metadata + + +def test_metadata_from_custom_model(model_class: t.Any) -> None: + if model_class is not None: + # In 2.x, SQLAlchemy creates the metadata attribute + base = model_class + else: + # For 1.x, our extension creates the metadata attribute + base = sa_orm.declarative_base(cls=Model, metaclass=DefaultMeta) metadata = base.metadata db = SQLAlchemy(model_class=base) assert db.Model.metadata is metadata assert db.Model.metadata is db.metadata -def test_custom_metadata_overrides_custom_model() -> None: +def test_custom_metadata_overrides_custom_model_legacy() -> None: base = sa_orm.declarative_base(cls=Model, metaclass=DefaultMeta) metadata = sa.MetaData() db = SQLAlchemy(model_class=base, metadata=metadata) @@ -41,18 +75,27 @@ def test_custom_metadata_overrides_custom_model() -> None: assert db.Model.metadata is db.metadata -def test_metadata_per_bind(app: Flask) -> None: +def test_metadata_per_bind(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app) + if model_class is not None: + db = SQLAlchemy(app, model_class=model_class) + else: + db = SQLAlchemy(app) assert db.metadatas["a"] is not db.metadata assert db.metadatas["a"].info["bind_key"] == "a" -def test_copy_naming_convention(app: Flask) -> None: +def test_copy_naming_convention(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy( - app, metadata=sa.MetaData(naming_convention={"pk": "spk_%(table_name)s"}) - ) + if model_class is not None: + model_class.metadata = sa.MetaData( + naming_convention={"pk": "spk_%(table_name)s"} + ) + db = SQLAlchemy(app, model_class=model_class) + else: + db = SQLAlchemy( + app, metadata=sa.MetaData(naming_convention={"pk": "spk_%(table_name)s"}) + ) assert db.metadata.naming_convention["pk"] == "spk_%(table_name)s" assert db.metadatas["a"].naming_convention == db.metadata.naming_convention From a89485baa289a92c8639b291a1bf76d646ff7e14 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 22 Aug 2023 12:12:33 -0700 Subject: [PATCH 19/27] Parameterize more tests for commit --- tests/conftest.py | 9 ++++---- tests/test_engine.py | 49 +++++++++++++++++++++++------------------- tests/test_metadata.py | 9 +++----- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 42eda133..d4ab92f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ from flask.ctx import AppContext from flask_sqlalchemy import SQLAlchemy +from flask_sqlalchemy.model import Model @pytest.fixture @@ -34,7 +35,7 @@ def app_ctx(app: Flask) -> t.Generator[AppContext, None, None]: # We defer creation of those classes until the fixture, # so that each test gets a fresh class with its own metadata. test_classes = [ - None, + Model, ( "BaseDeclarativeBase", (sa_orm.DeclarativeBase,), @@ -63,7 +64,7 @@ def app_ctx(app: Flask) -> t.Generator[AppContext, None, None]: @pytest.fixture(params=test_classes) def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy: - if request.param is not None: + if request.param is not Model: return SQLAlchemy(app, model_class=types.new_class(*request.param)) else: return SQLAlchemy(app) @@ -71,10 +72,10 @@ def db(app: Flask, request: pytest.FixtureRequest) -> SQLAlchemy: @pytest.fixture(params=test_classes) def model_class(request: pytest.FixtureRequest) -> t.Any: - if request.param is not None: + if request.param is not Model: return types.new_class(*request.param) else: - return None + return request.param @pytest.fixture diff --git a/tests/test_engine.py b/tests/test_engine.py index 37d9d2e9..0e88d5e3 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,6 +1,7 @@ from __future__ import annotations import os.path +import typing as t import unittest.mock import pytest @@ -20,24 +21,24 @@ def test_default_engine(app: Flask, db: SQLAlchemy) -> None: @pytest.mark.usefixtures("app_ctx") -def test_engine_per_bind(app: Flask) -> None: +def test_engine_per_bind(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) assert db.engines["a"] is not db.engine @pytest.mark.usefixtures("app_ctx") -def test_config_engine_options(app: Flask) -> None: +def test_config_engine_options(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"echo": True} - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) assert db.engine.echo @pytest.mark.usefixtures("app_ctx") -def test_init_engine_options(app: Flask) -> None: +def test_init_engine_options(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_ENGINE_OPTIONS"] = {"echo": False} app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app, engine_options={"echo": True}) + db = SQLAlchemy(app, engine_options={"echo": True}, model_class=model_class) # init is default assert db.engines["a"].echo # config overrides init @@ -45,9 +46,9 @@ def test_init_engine_options(app: Flask) -> None: @pytest.mark.usefixtures("app_ctx") -def test_config_echo(app: Flask) -> None: +def test_config_echo(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_ECHO"] = True - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) assert db.engine.echo assert db.engine.pool.echo @@ -62,35 +63,35 @@ def test_config_echo(app: Flask) -> None: {"url": sa.engine.URL.create("sqlite")}, ], ) -def test_url_type(app: Flask, value: str | sa.engine.URL) -> None: +def test_url_type(app: Flask, model_class: t.Any, value: str | sa.engine.URL) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": value} - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) assert str(db.engines["a"].url) == "sqlite://" -def test_no_binds_error(app: Flask) -> None: +def test_no_binds_error(app: Flask, model_class: t.Any) -> None: del app.config["SQLALCHEMY_DATABASE_URI"] with pytest.raises(RuntimeError) as info: - SQLAlchemy(app) + SQLAlchemy(app, model_class=model_class) e = "Either 'SQLALCHEMY_DATABASE_URI' or 'SQLALCHEMY_BINDS' must be set." assert str(info.value) == e @pytest.mark.usefixtures("app_ctx") -def test_no_default_url(app: Flask) -> None: +def test_no_default_url(app: Flask, model_class: t.Any) -> None: del app.config["SQLALCHEMY_DATABASE_URI"] app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app, engine_options={"echo": True}) + db = SQLAlchemy(app, model_class=model_class, engine_options={"echo": True}) assert None not in db.engines assert "a" in db.engines @pytest.mark.usefixtures("app_ctx") -def test_sqlite_relative_path(app: Flask) -> None: +def test_sqlite_relative_path(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///test.db" - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) db.create_all() assert not isinstance(db.engine.pool, sa.pool.StaticPool) db_path = db.engine.url.database @@ -99,9 +100,9 @@ def test_sqlite_relative_path(app: Flask) -> None: @pytest.mark.usefixtures("app_ctx") -def test_sqlite_driver_level_uri(app: Flask) -> None: +def test_sqlite_driver_level_uri(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///file:test.db?uri=true" - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) db.create_all() db_path = db.engine.url.database assert db_path is not None @@ -110,17 +111,21 @@ def test_sqlite_driver_level_uri(app: Flask) -> None: @unittest.mock.patch.object(SQLAlchemy, "_make_engine", autospec=True) -def test_sqlite_memory_defaults(make_engine: unittest.mock.Mock, app: Flask) -> None: - SQLAlchemy(app) +def test_sqlite_memory_defaults( + make_engine: unittest.mock.Mock, app: Flask, model_class: t.Any +) -> None: + SQLAlchemy(app, model_class=model_class) options = make_engine.call_args[0][2] assert options["poolclass"] is sa.pool.StaticPool assert options["connect_args"]["check_same_thread"] is False @unittest.mock.patch.object(SQLAlchemy, "_make_engine", autospec=True) -def test_mysql_defaults(make_engine: unittest.mock.Mock, app: Flask) -> None: +def test_mysql_defaults( + make_engine: unittest.mock.Mock, app: Flask, model_class: t.Any +) -> None: app.config["SQLALCHEMY_DATABASE_URI"] = "mysql:///test" - SQLAlchemy(app) + SQLAlchemy(app, model_class=model_class) options = make_engine.call_args[0][2] assert options["pool_recycle"] == 7200 assert options["url"].query["charset"] == "utf8mb4" diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 26ad5358..706609ae 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -55,7 +55,7 @@ class Base(sa_orm.DeclarativeBase): def test_metadata_from_custom_model(model_class: t.Any) -> None: - if model_class is not None: + if model_class is not Model: # In 2.x, SQLAlchemy creates the metadata attribute base = model_class else: @@ -77,17 +77,14 @@ def test_custom_metadata_overrides_custom_model_legacy() -> None: def test_metadata_per_bind(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - if model_class is not None: - db = SQLAlchemy(app, model_class=model_class) - else: - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) assert db.metadatas["a"] is not db.metadata assert db.metadatas["a"].info["bind_key"] == "a" def test_copy_naming_convention(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - if model_class is not None: + if model_class is not Model: model_class.metadata = sa.MetaData( naming_convention={"pk": "spk_%(table_name)s"} ) From 6f4c1804e2d214f1cc8f971a79ca259ea5e9037d Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 22 Aug 2023 13:39:33 -0700 Subject: [PATCH 20/27] Parameterize more tests to use 2.x --- tests/test_extension_object.py | 45 ++++++++++--- tests/test_model.py | 42 +++++++++++- tests/test_record_queries.py | 29 +++++++-- tests/test_session.py | 102 ++++++++++++++++++++++-------- tests/test_track_modifications.py | 38 ++++++++--- 5 files changed, 206 insertions(+), 50 deletions(-) diff --git a/tests/test_extension_object.py b/tests/test_extension_object.py index 2830077b..b7ffcb06 100644 --- a/tests/test_extension_object.py +++ b/tests/test_extension_object.py @@ -4,6 +4,7 @@ import pytest import sqlalchemy as sa +import sqlalchemy.orm as sa_orm from flask import Flask from werkzeug.exceptions import NotFound @@ -22,17 +23,45 @@ def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None: db.get_or_404(Todo, 2) -def test_get_or_404_kwargs(app: Flask) -> None: +def test_get_or_404_kwargs(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_RECORD_QUERIES"] = True - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) - class User(db.Model): - id = sa.Column(db.Integer, primary_key=True) # type: ignore[var-annotated] + if issubclass(db.Model, (sa_orm.MappedAsDataclass)): - class Todo(db.Model): - id = sa.Column(sa.Integer, primary_key=True) - user_id = sa.Column(sa.ForeignKey(User.id)) # type: ignore[var-annotated] - user = db.relationship(User) + class User(db.Model): # type: ignore[no-redef] + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + sa.Integer, primary_key=True, init=False + ) + + class Todo(db.Model): # type: ignore[no-redef] + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + sa.Integer, primary_key=True, init=False + ) + user_id: sa_orm.Mapped[int] = sa_orm.mapped_column( + sa.ForeignKey(User.id), init=False + ) + user: sa_orm.Mapped[User] = sa_orm.relationship(User) + + elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): + + class User(db.Model): # type: ignore[no-redef] + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) + + class Todo(db.Model): # type: ignore[no-redef] + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) + user_id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.ForeignKey(User.id)) + user: sa_orm.Mapped[User] = sa_orm.relationship(User) + + else: + + class User(db.Model): # type: ignore[no-redef] + id = sa.Column(db.Integer, primary_key=True) # type: ignore[var-annotated] + + class Todo(db.Model): # type: ignore[no-redef] + id = sa.Column(sa.Integer, primary_key=True) + user_id = sa.Column(sa.ForeignKey(User.id)) # type: ignore[var-annotated] + user = db.relationship(User) with app.app_context(): db.create_all() diff --git a/tests/test_model.py b/tests/test_model.py index 1b316f0c..41ee1e2c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -13,7 +13,7 @@ from flask_sqlalchemy.model import Model -def test_default_model_class(app: Flask) -> None: +def test_default_model_class_1x(app: Flask) -> None: db = SQLAlchemy(app) assert db.Model.query_class is db.Query @@ -22,7 +22,7 @@ def test_default_model_class(app: Flask) -> None: assert isinstance(db.Model, DefaultMeta) -def test_custom_model_class(app: Flask) -> None: +def test_custom_model_class_1x(app: Flask) -> None: class CustomModel(Model): pass @@ -33,7 +33,7 @@ class CustomModel(Model): @pytest.mark.usefixtures("app_ctx") @pytest.mark.parametrize("base", [Model, object]) -def test_custom_declarative_class(app: Flask, base: t.Any) -> None: +def test_custom_declarative_class_1x(app: Flask, base: t.Any) -> None: class CustomMeta(DefaultMeta): pass @@ -44,6 +44,42 @@ class CustomMeta(DefaultMeta): assert "query" in db.Model.__dict__ +def test_declarativebase_2x(app: Flask) -> None: + class Base(sa_orm.DeclarativeBase): + pass + + db = SQLAlchemy(app, model_class=Base) + assert issubclass(db.Model, sa_orm.DeclarativeBase) + assert isinstance(db.Model, sa_orm.decl_api.DeclarativeAttributeIntercept) + + +def test_declarativebasenometa_2x(app: Flask) -> None: + class Base(sa_orm.DeclarativeBaseNoMeta): + pass + + db = SQLAlchemy(app, model_class=Base) + assert issubclass(db.Model, sa_orm.DeclarativeBaseNoMeta) + assert not isinstance(db.Model, sa_orm.decl_api.DeclarativeAttributeIntercept) + + +def test_declarativebasemapped_2x(app: Flask) -> None: + class Base(sa_orm.DeclarativeBase, sa_orm.MappedAsDataclass): + pass + + db = SQLAlchemy(app, model_class=Base) + assert issubclass(db.Model, sa_orm.DeclarativeBase) + assert isinstance(db.Model, sa_orm.decl_api.DCTransformDeclarative) + + +def test_declarativebasenometamapped_2x(app: Flask) -> None: + class Base(sa_orm.DeclarativeBaseNoMeta, sa_orm.MappedAsDataclass): + pass + + db = SQLAlchemy(app, model_class=Base) + assert issubclass(db.Model, sa_orm.DeclarativeBaseNoMeta) + assert isinstance(db.Model, sa_orm.decl_api.DCTransformDeclarative) + + @pytest.mark.usefixtures("app_ctx") def test_model_repr(db: SQLAlchemy) -> None: class User(db.Model): diff --git a/tests/test_record_queries.py b/tests/test_record_queries.py index 50c5cb76..c5cc73a2 100644 --- a/tests/test_record_queries.py +++ b/tests/test_record_queries.py @@ -4,6 +4,7 @@ import pytest import sqlalchemy as sa +import sqlalchemy.orm as sa_orm from flask import Flask from flask_sqlalchemy import SQLAlchemy @@ -15,15 +16,35 @@ def test_query_info(app: Flask) -> None: app.config["SQLALCHEMY_RECORD_QUERIES"] = True db = SQLAlchemy(app) - class Example(db.Model): - id = sa.Column(sa.Integer, primary_key=True) + # Copied and pasted from conftest.py + if issubclass(db.Model, (sa_orm.MappedAsDataclass)): + + class Todo(db.Model): + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + sa.Integer, init=False, primary_key=True + ) + title: sa_orm.Mapped[str] = sa_orm.mapped_column( + sa.String, nullable=True, default=None + ) + + elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): + + class Todo(db.Model): # type: ignore[no-redef] + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) + title: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=True) + + else: + + class Todo(db.Model): # type: ignore[no-redef] + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String) db.create_all() - db.session.execute(sa.select(Example).filter(Example.id < 5)).scalars() + db.session.execute(sa.select(Todo).filter(Todo.id < 5)).scalars() info = get_recorded_queries()[-1] assert info.statement is not None assert "SELECT" in info.statement - assert "FROM example" in info.statement + assert "FROM todo" in info.statement assert info.parameters[0][0] == 5 assert info.duration == info.end_time - info.start_time assert os.path.join("tests", "test_record_queries.py:") in info.location diff --git a/tests/test_session.py b/tests/test_session.py index 40dd0233..a274739a 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,7 +1,10 @@ from __future__ import annotations +import typing as t + import pytest import sqlalchemy as sa +import sqlalchemy.orm as sa_orm from flask import Flask from flask_sqlalchemy import SQLAlchemy @@ -23,7 +26,7 @@ def test_scope(app: Flask, db: SQLAlchemy) -> None: assert first is not third -def test_custom_scope(app: Flask) -> None: +def test_custom_scope(app: Flask, model_class: t.Any) -> None: count = 0 def scope() -> int: @@ -31,7 +34,7 @@ def scope() -> int: count += 1 return count - db = SQLAlchemy(app, session_options={"scopefunc": scope}) + db = SQLAlchemy(app, model_class=model_class, session_options={"scopefunc": scope}) with app.app_context(): first = db.session() @@ -42,47 +45,94 @@ def scope() -> int: @pytest.mark.usefixtures("app_ctx") -def test_session_class(app: Flask) -> None: +def test_session_class(app: Flask, model_class: t.Any) -> None: class CustomSession(Session): pass - db = SQLAlchemy(app, session_options={"class_": CustomSession}) + db = SQLAlchemy( + app, model_class=model_class, session_options={"class_": CustomSession} + ) assert isinstance(db.session(), CustomSession) @pytest.mark.usefixtures("app_ctx") -def test_session_uses_bind_key(app: Flask) -> None: +def test_session_uses_bind_key(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) - class User(db.Model): - id = sa.Column(sa.Integer, primary_key=True) + if issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): - class Post(db.Model): - __bind_key__ = "a" - id = sa.Column(sa.Integer, primary_key=True) + class User(db.Model): + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) - assert db.session.get_bind(mapper=User) is db.engine - assert db.session.get_bind(mapper=Post) is db.engines["a"] + class Post(db.Model): + __bind_key__ = "a" + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) + else: -@pytest.mark.usefixtures("app_ctx") -def test_get_bind_inheritance(app: Flask) -> None: - app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} - db = SQLAlchemy(app) + class User(db.Model): # type: ignore[no-redef] + id = sa.Column(sa.Integer, primary_key=True) - class User(db.Model): - __bind_key__ = "a" - id = sa.Column(sa.Integer, primary_key=True) - type = sa.Column(sa.String, nullable=False) + class Post(db.Model): # type: ignore[no-redef] + __bind_key__ = "a" + id = sa.Column(sa.Integer, primary_key=True) - __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"} + assert db.session.get_bind(mapper=User) is db.engine + assert db.session.get_bind(mapper=Post) is db.engines["a"] - class Admin(User): - id = sa.Column(sa.ForeignKey(User.id), primary_key=True) - org = sa.Column(sa.String, nullable=False) - __mapper_args__ = {"polymorphic_identity": "admin"} +@pytest.mark.usefixtures("app_ctx") +def test_get_bind_inheritance(app: Flask, model_class: t.Any) -> None: + app.config["SQLALCHEMY_BINDS"] = {"a": "sqlite://"} + db = SQLAlchemy(app, model_class=model_class) + + if issubclass(db.Model, (sa_orm.MappedAsDataclass)): + + class User(db.Model): + __bind_key__ = "a" + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + sa.Integer, primary_key=True, init=False + ) + type: sa_orm.Mapped[str] = sa_orm.mapped_column( + sa.String, nullable=False, init=False + ) + __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"} + + class Admin(User): + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + sa.ForeignKey(User.id), primary_key=True, init=False + ) + org: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=False) + __mapper_args__ = {"polymorphic_identity": "admin"} + + elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): + + class User(db.Model): + __bind_key__ = "a" + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) + type: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=False) + __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"} + + class Admin(User): + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + sa.ForeignKey(User.id), primary_key=True + ) + org: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=False) + __mapper_args__ = {"polymorphic_identity": "admin"} + + else: + + class User(db.Model): # type: ignore[no-redef] + __bind_key__ = "a" + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(sa.String, nullable=False) + __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"} + + class Admin(User): # type: ignore[no-redef] + id = sa.Column(sa.ForeignKey(User.id), primary_key=True) + org = sa.Column(sa.String, nullable=False) + __mapper_args__ = {"polymorphic_identity": "admin"} db.create_all() db.session.add(Admin(org="pallets")) diff --git a/tests/test_track_modifications.py b/tests/test_track_modifications.py index ff895bd5..09d5ff53 100644 --- a/tests/test_track_modifications.py +++ b/tests/test_track_modifications.py @@ -4,6 +4,7 @@ import pytest import sqlalchemy as sa +import sqlalchemy.orm as sa_orm from flask import Flask from flask_sqlalchemy import SQLAlchemy @@ -14,13 +15,32 @@ @pytest.mark.usefixtures("app_ctx") -def test_track_modifications(app: Flask) -> None: +def test_track_modifications(app: Flask, model_class: t.Any) -> None: app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = True - db = SQLAlchemy(app) + db = SQLAlchemy(app, model_class=model_class) - class Example(db.Model): - id = sa.Column(sa.Integer, primary_key=True) - data = sa.Column(sa.String) + # Copied and pasted from conftest.py + if issubclass(db.Model, (sa_orm.MappedAsDataclass)): + + class Todo(db.Model): + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + sa.Integer, init=False, primary_key=True + ) + title: sa_orm.Mapped[str] = sa_orm.mapped_column( + sa.String, nullable=True, default=None + ) + + elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): + + class Todo(db.Model): # type: ignore[no-redef] + id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) + title: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=True) + + else: + + class Todo(db.Model): # type: ignore[no-redef] + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.String) db.create_all() before: list[tuple[t.Any, str]] = [] @@ -38,7 +58,7 @@ def after_commit(sender: Flask, changes: list[tuple[t.Any, str]]) -> None: connect_after = models_committed.connected_to(after_commit, app) with connect_before, connect_after: - item = Example() + item = Todo() db.session.add(item) assert not before @@ -50,15 +70,15 @@ def after_commit(sender: Flask, changes: list[tuple[t.Any, str]]) -> None: assert before == after db.session.remove() - item = db.session.get(Example, 1) # type: ignore[assignment] - item.data = "test" # type: ignore[assignment] + item = db.session.get(Todo, 1) # type: ignore[assignment] + item.title = "test" # type: ignore[assignment] db.session.commit() assert len(before) == 1 assert before[0] == (item, "update") assert before == after db.session.remove() - item = db.session.get(Example, 1) # type: ignore[assignment] + item = db.session.get(Todo, 1) # type: ignore[assignment] db.session.delete(item) db.session.commit() assert len(before) == 1 From caee46ae0846d9808204d0db426657a597146960 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 22 Aug 2023 13:51:23 -0700 Subject: [PATCH 21/27] Update README --- README.rst | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 6299060e..92e3b29e 100644 --- a/README.rst +++ b/README.rst @@ -29,14 +29,19 @@ A Simple Example from flask import Flask from flask_sqlalchemy import SQLAlchemy + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column app = Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///example.sqlite" - db = SQLAlchemy(app) + + class Base(DeclarativeBase): + pass + + db = SQLAlchemy(app, model_class=Base) class User(db.Model): - id = db.Column(db.Integer, primary_key=True) - username = db.Column(db.String, unique=True, nullable=False) + id: Mapped[int] = mapped_column(db.Integer, primary_key=True) + username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False) with app.app_context(): db.create_all() From 6d7a7e3c0fe3b3e94fba92712d568bb4a34458e0 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 22 Aug 2023 14:12:06 -0700 Subject: [PATCH 22/27] Style and sphinx checks --- src/flask_sqlalchemy/extension.py | 25 +++++++++++++++++-------- tests/test_extension_object.py | 4 ++-- tests/test_session.py | 6 +++--- tests/test_track_modifications.py | 2 +- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 9e6e4e50..300eb127 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -48,7 +48,9 @@ class _FSAModel(Model): metadata: sa.MetaData -def _get_2x_declarative_bases(model_class: t.Type[_FSA_MC]) -> list[t.Type]: +def _get_2x_declarative_bases( + model_class: t.Type[_FSA_MC], +) -> list[t.Type[t.Union[sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta]]]: return [ b for b in model_class.__bases__ @@ -81,10 +83,6 @@ class SQLAlchemy: :param app: Call :meth:`init_app` on this Flask application now. :param metadata: Use this as the default :class:`sqlalchemy.schema.MetaData`. Useful for setting a naming convention. - .. deprecated:: 3.1.0 - This parameter can still be used in conjunction with SQLAlchemy 1.x classes, - but is ignored when using SQLAlchemy 2.x style of declarative classes. - Instead, specify metadata on your Base class. :param session_options: Arguments used by :attr:`session` to create each session instance. A ``scopefunc`` key will be passed to the scoped session, not the session instance. See :class:`sqlalchemy.orm.sessionmaker` for a list of @@ -101,8 +99,16 @@ class SQLAlchemy: ``flask shell``. .. versionchanged:: 3.1.0 - Added the ``disable_autonaming`` parameter and changed ``model_class`` parameter - to accept a SQLAlchemy 2.0-style declarative base subclass. + The ``metadata`` parameter can still be used with SQLAlchemy 1.x classes, + but is ignored when using SQLAlchemy 2.x style of declarative classes. + Instead, specify metadata on your Base class. + + .. versionchanged:: 3.1.0 + Added the ``disable_autonaming`` parameter. + + .. versionchanged:: 3.1.0 + Changed ``model_class`` parameter to accepta SQLAlchemy 2.x + declarative base subclass. .. versionchanged:: 3.0 An active Flask application context is always required to access ``session`` and @@ -534,7 +540,10 @@ def _make_declarative_base( " Got: {}".format(model_class.__bases__) ) elif len(declarative_bases) == 1: - body = {"__fsa__": self, "metadata": model_class.metadata} + body = { + "__fsa__": self, + "metadata": model_class.metadata, # type: ignore[attr-defined] + } mixin_classes = [BindMixin, NameMixin, Model] if disable_autonaming: mixin_classes.remove(NameMixin) diff --git a/tests/test_extension_object.py b/tests/test_extension_object.py index b7ffcb06..0cb5a608 100644 --- a/tests/test_extension_object.py +++ b/tests/test_extension_object.py @@ -29,12 +29,12 @@ def test_get_or_404_kwargs(app: Flask, model_class: t.Any) -> None: if issubclass(db.Model, (sa_orm.MappedAsDataclass)): - class User(db.Model): # type: ignore[no-redef] + class User(db.Model): id: sa_orm.Mapped[int] = sa_orm.mapped_column( sa.Integer, primary_key=True, init=False ) - class Todo(db.Model): # type: ignore[no-redef] + class Todo(db.Model): id: sa_orm.Mapped[int] = sa_orm.mapped_column( sa.Integer, primary_key=True, init=False ) diff --git a/tests/test_session.py b/tests/test_session.py index a274739a..9e161741 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -108,13 +108,13 @@ class Admin(User): elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): - class User(db.Model): + class User(db.Model): # type: ignore[no-redef] __bind_key__ = "a" id: sa_orm.Mapped[int] = sa_orm.mapped_column(sa.Integer, primary_key=True) type: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.String, nullable=False) __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"} - class Admin(User): + class Admin(User): # type: ignore[no-redef] id: sa_orm.Mapped[int] = sa_orm.mapped_column( sa.ForeignKey(User.id), primary_key=True ) @@ -130,7 +130,7 @@ class User(db.Model): # type: ignore[no-redef] __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"} class Admin(User): # type: ignore[no-redef] - id = sa.Column(sa.ForeignKey(User.id), primary_key=True) + id = sa.Column(sa.ForeignKey(User.id), primary_key=True) # type: ignore[assignment] org = sa.Column(sa.String, nullable=False) __mapper_args__ = {"polymorphic_identity": "admin"} diff --git a/tests/test_track_modifications.py b/tests/test_track_modifications.py index 09d5ff53..0d596280 100644 --- a/tests/test_track_modifications.py +++ b/tests/test_track_modifications.py @@ -71,7 +71,7 @@ def after_commit(sender: Flask, changes: list[tuple[t.Any, str]]) -> None: db.session.remove() item = db.session.get(Todo, 1) # type: ignore[assignment] - item.title = "test" # type: ignore[assignment] + item.title = "test" db.session.commit() assert len(before) == 1 assert before[0] == (item, "update") From 1d877e825f2254b6130c1bfa81d182681734a83c Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 22 Aug 2023 16:05:59 -0700 Subject: [PATCH 23/27] Documentation updates- remove new mixins from ref, add some comments --- docs/api.rst | 15 +-------------- src/flask_sqlalchemy/model.py | 3 +++ 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 6bcc5fb7..793f1030 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -36,8 +36,7 @@ Metaclass mixins (SQLAlchemy 1.x) --------------------------------- If your code uses the SQLAlchemy 1.x API (the default for code that doesn't specify a ``model_class``), -then these mixins are automatically applied to the ``Model`` class. They can also be used -directly to create custom metaclasses. See :doc:`customizing` for more information. +then these mixins are automatically applied to the ``Model`` class. .. autoclass:: DefaultMeta @@ -46,18 +45,6 @@ directly to create custom metaclasses. See :doc:`customizing` for more informati .. autoclass:: NameMetaMixin -Base class mixins (SQLAlchemy 2.x) ----------------------------------- - -If your code uses the SQLAlchemy 2.x API by passing a subclass of ``DeclarativeBase`` -or ``DeclarativeBaseNoMeta`` as the ``model_class``, then the following classes -are automatically added as additional base classes. - -.. autoclass:: BindMixin - -.. autoclass:: NameMixin - - Session ------- diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index d1989a66..c6f9e5a9 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -300,8 +300,11 @@ def should_set_tablename(cls: type) -> bool: base is cls or base.__dict__.get("__abstract__", False) or not ( + # SQLAlchemy 1.x isinstance(base, sa_orm.DeclarativeMeta) + # 2.x: DeclarativeBas uses this as metaclass or isinstance(base, sa_orm.decl_api.DeclarativeAttributeIntercept) + # 2.x: DeclarativeBaseNoMeta doesn't use a metaclass or issubclass(base, sa_orm.DeclarativeBaseNoMeta) ) ) From a2db2160461093ae485b1e58cc356fdbc1e930c0 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Tue, 22 Aug 2023 16:45:42 -0700 Subject: [PATCH 24/27] Update abstract models and mixins --- docs/customizing.rst | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/docs/customizing.rst b/docs/customizing.rst index 7540a729..1c170686 100644 --- a/docs/customizing.rst +++ b/docs/customizing.rst @@ -59,25 +59,28 @@ they are created or updated. class TimestampModel(db.Model): __abstract__ = True - created = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) - updated = db.Column(db.DateTime, onupdate=datetime.utcnow) + created: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, default=datetime.utcnow) + updated: Mapped[datetime] = mapped_column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) class Author(db.Model): - ... + id: Mapped[int] = mapped_column(db.Integer, primary_key=True) + username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False) class Post(TimestampModel): - ... + id: Mapped[int] = mapped_column(db.Integer, primary_key=True) + title: Mapped[str] = mapped_column(db.String, nullable=False) This can also be done with a mixin class, inheriting from ``db.Model`` separately. .. code-block:: python - class TimestampMixin: - created = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) - updated = db.Column(db.DateTime, onupdate=datetime.utcnow) + class TimestampModel: + created: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, default=datetime.utcnow) + updated: Mapped[datetime] = mapped_column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) - class Post(TimestampMixin, db.Model): - ... + class Post2(TimestampModel, db.Model): + id: Mapped[int] = mapped_column(db.Integer, primary_key=True) + title: Mapped[str] = mapped_column(db.String, nullable=False) Session Class @@ -168,14 +171,6 @@ on Flask-SQLAlchemy's detection and generation. The simple way to achieve that i set each ``__tablename__`` and not modify the base class. However, the table name generation can be disabled by setting `disable_autonaming=True` in the `SQLAlchemy` constructor. -Example code using the SQLAlchemy 1.x (legacy) API: - -.. code-block:: python - - db = SQLAlchemy(app, disable_autonaming=True) - -Example code using the SQLAlchemy 2.x declarative base: - .. code-block:: python class Base(sa_orm.DeclarativeBase): From cf4170adcdaedb226629f860b2745e895d959f5d Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Fri, 25 Aug 2023 16:28:51 -0700 Subject: [PATCH 25/27] Add test for declared_attr --- docs/customizing.rst | 83 +++++++------- src/flask_sqlalchemy/extension.py | 6 +- tests/test_model.py | 177 ++++++++++++++++++++++++++++++ 3 files changed, 220 insertions(+), 46 deletions(-) diff --git a/docs/customizing.rst b/docs/customizing.rst index 1c170686..05ab71db 100644 --- a/docs/customizing.rst +++ b/docs/customizing.rst @@ -21,29 +21,26 @@ joined-table inheritance. .. code-block:: python - from flask_sqlalchemy.model import Model - import sqlalchemy as sa - import sqlalchemy.orm + from sqlalchemy import Integer, String, ForeignKey + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, declared_attr - class IdModel(Model): - @sa.orm.declared_attr + class Base(DeclarativeBase): + @declared_attr.cascading + @classmethod def id(cls): for base in cls.__mro__[1:-1]: if getattr(base, "__table__", None) is not None: - type = sa.ForeignKey(base.id) - break - else: - type = sa.Integer + return mapped_column(ForeignKey(base.id), primary_key=True) + else: + return mapped_column(Integer, primary_key=True) - return sa.Column(type, primary_key=True) - - db = SQLAlchemy(model_class=IdModel) + db = SQLAlchemy(app, model_class=Base) class User(db.Model): - name = db.Column(db.String) + name: Mapped[str] = mapped_column(String) class Employee(User): - title = db.Column(db.String) + title: Mapped[str] = mapped_column(String) Abstract Models and Mixins @@ -56,31 +53,49 @@ they are created or updated. .. code-block:: python from datetime import datetime + from sqlalchemy import DateTime, Integer, String + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, declared_attr class TimestampModel(db.Model): __abstract__ = True - created: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, default=datetime.utcnow) - updated: Mapped[datetime] = mapped_column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + created: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow) + updated: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) class Author(db.Model): - id: Mapped[int] = mapped_column(db.Integer, primary_key=True) - username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False) + id: Mapped[int] = mapped_column(Integer, primary_key=True) + username: Mapped[str] = mapped_column(String, unique=True, nullable=False) class Post(TimestampModel): - id: Mapped[int] = mapped_column(db.Integer, primary_key=True) - title: Mapped[str] = mapped_column(db.String, nullable=False) + id: Mapped[int] = mapped_column(Integer, primary_key=True) + title: Mapped[str] = mapped_column(String, nullable=False) This can also be done with a mixin class, inheriting from ``db.Model`` separately. .. code-block:: python - class TimestampModel: - created: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, default=datetime.utcnow) - updated: Mapped[datetime] = mapped_column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + class TimestampMixin: + created: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow) + updated: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + class Post(TimestampMixin, db.Model): + id: Mapped[int] = mapped_column(Integer, primary_key=True) + title: Mapped[str] = mapped_column(String, nullable=False) + - class Post2(TimestampModel, db.Model): - id: Mapped[int] = mapped_column(db.Integer, primary_key=True) - title: Mapped[str] = mapped_column(db.String, nullable=False) +Disabling Table Name Generation +------------------------------- + +Some projects prefer to set each model's ``__tablename__`` manually rather than relying +on Flask-SQLAlchemy's detection and generation. The simple way to achieve that is to +set each ``__tablename__`` and not modify the base class. However, the table name +generation can be disabled by setting `disable_autonaming=True` in the `SQLAlchemy` constructor. + +.. code-block:: python + + class Base(sa_orm.DeclarativeBase): + pass + + db = SQLAlchemy(app, model_class=Base, disable_autonaming=True) Session Class @@ -161,19 +176,3 @@ To customize only ``session.query``, pass the ``query_cls`` key to the .. code-block:: python db = SQLAlchemy(session_options={"query_cls": GetOrQuery}) - - -Disabling Table Name Generation -------------------------------- - -Some projects prefer to set each model's ``__tablename__`` manually rather than relying -on Flask-SQLAlchemy's detection and generation. The simple way to achieve that is to -set each ``__tablename__`` and not modify the base class. However, the table name -generation can be disabled by setting `disable_autonaming=True` in the `SQLAlchemy` constructor. - -.. code-block:: python - - class Base(sa_orm.DeclarativeBase): - pass - - db = SQLAlchemy(app, model_class=Base, disable_autonaming=True) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 300eb127..0ce065a5 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -540,10 +540,8 @@ def _make_declarative_base( " Got: {}".format(model_class.__bases__) ) elif len(declarative_bases) == 1: - body = { - "__fsa__": self, - "metadata": model_class.metadata, # type: ignore[attr-defined] - } + body = dict(model_class.__dict__) # type: ignore[arg-type] + body["__fsa__"] = self mixin_classes = [BindMixin, NameMixin, Model] if disable_autonaming: mixin_classes.remove(NameMixin) diff --git a/tests/test_model.py b/tests/test_model.py index 41ee1e2c..0968a1e2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +from datetime import datetime import pytest import sqlalchemy as sa @@ -80,6 +81,182 @@ class Base(sa_orm.DeclarativeBaseNoMeta, sa_orm.MappedAsDataclass): assert isinstance(db.Model, sa_orm.decl_api.DCTransformDeclarative) +@pytest.mark.usefixtures("app_ctx") +def test_declaredattr(app: Flask, model_class: t.Any) -> None: + if model_class is Model: + + class IdModel(Model): + @sa.orm.declared_attr + @classmethod + def id(cls: type[Model]): # type: ignore[no-untyped-def] + for base in cls.__mro__[1:-1]: + if getattr(base, "__table__", None) is not None and hasattr( + base, "id" + ): + return sa.Column(sa.ForeignKey(base.id), primary_key=True) + return sa.Column(sa.Integer, primary_key=True) + + db = SQLAlchemy(app, model_class=IdModel) + + class User(db.Model): + name = db.Column(db.String) + + class Employee(User): + title = db.Column(db.String) + + else: + + class Base(sa_orm.DeclarativeBase): + @sa_orm.declared_attr + @classmethod + def id(cls: type[sa_orm.DeclarativeBase]) -> sa_orm.Mapped[int]: + for base in cls.__mro__[1:-1]: + if getattr(base, "__table__", None) is not None and hasattr( + base, "id" + ): + return sa_orm.mapped_column( + db.ForeignKey(base.id), primary_key=True + ) + return sa_orm.mapped_column(db.Integer, primary_key=True) + + db = SQLAlchemy(app, model_class=Base) + + class User(db.Model): # type: ignore[no-redef] + name: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String) + + class Employee(User): # type: ignore[no-redef] + title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String) + + db.create_all() + db.session.add(Employee(name="Emp Loyee", title="Admin")) + db.session.commit() + user = db.session.execute(db.select(User)).scalar() + employee = db.session.execute(db.select(Employee)).scalar() + assert user is not None + assert employee is not None + assert user.id == 1 + assert employee.id == 1 + + +@pytest.mark.usefixtures("app_ctx") +def test_abstractmodel(app: Flask, model_class: t.Any) -> None: + db = SQLAlchemy(app, model_class=model_class) + + if issubclass(db.Model, (sa_orm.MappedAsDataclass)): + + class TimestampModel(db.Model): + __abstract__ = True + created: sa_orm.Mapped[datetime] = sa_orm.mapped_column( + db.DateTime, nullable=False, insert_default=datetime.utcnow, init=False + ) + updated: sa_orm.Mapped[datetime] = sa_orm.mapped_column( + db.DateTime, + insert_default=datetime.utcnow, + onupdate=datetime.utcnow, + init=False, + ) + + class Post(TimestampModel): + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + db.Integer, primary_key=True, init=False + ) + title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False) + + elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): + + class TimestampModel(db.Model): # type: ignore[no-redef] + __abstract__ = True + created: sa_orm.Mapped[datetime] = sa_orm.mapped_column( + db.DateTime, nullable=False, default=datetime.utcnow + ) + updated: sa_orm.Mapped[datetime] = sa_orm.mapped_column( + db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + + class Post(TimestampModel): # type: ignore[no-redef] + id: sa_orm.Mapped[int] = sa_orm.mapped_column(db.Integer, primary_key=True) + title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False) + + else: + + class TimestampModel(db.Model): # type: ignore[no-redef] + __abstract__ = True + created = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) + updated = db.Column( + db.DateTime, onupdate=datetime.utcnow, default=datetime.utcnow + ) + + class Post(TimestampModel): # type: ignore[no-redef] + id = db.Column(db.Integer, primary_key=True) + title = db.Column(db.String, nullable=False) + + db.create_all() + db.session.add(Post(title="Admin Post")) + db.session.commit() + post = db.session.execute(db.select(Post)).scalar() + assert post is not None + assert post.created is not None + assert post.updated is not None + + +@pytest.mark.usefixtures("app_ctx") +def test_mixinmodel(app: Flask, model_class: t.Any) -> None: + db = SQLAlchemy(app, model_class=model_class) + + if issubclass(db.Model, (sa_orm.MappedAsDataclass)): + + class TimestampMixin(sa_orm.MappedAsDataclass): + created: sa_orm.Mapped[datetime] = sa_orm.mapped_column( + db.DateTime, nullable=False, insert_default=datetime.utcnow, init=False + ) + updated: sa_orm.Mapped[datetime] = sa_orm.mapped_column( + db.DateTime, + insert_default=datetime.utcnow, + onupdate=datetime.utcnow, + init=False, + ) + + class Post(TimestampMixin, db.Model): + id: sa_orm.Mapped[int] = sa_orm.mapped_column( + db.Integer, primary_key=True, init=False + ) + title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False) + + elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)): + + class TimestampMixin: # type: ignore[no-redef] + created: sa_orm.Mapped[datetime] = sa_orm.mapped_column( + db.DateTime, nullable=False, default=datetime.utcnow + ) + updated: sa_orm.Mapped[datetime] = sa_orm.mapped_column( + db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + + class Post(TimestampMixin, db.Model): # type: ignore[no-redef] + id: sa_orm.Mapped[int] = sa_orm.mapped_column(db.Integer, primary_key=True) + title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False) + + else: + + class TimestampMixin: # type: ignore[no-redef] + created = db.Column(db.DateTime, nullable=False, default=datetime.utcnow) + updated = db.Column( + db.DateTime, onupdate=datetime.utcnow, default=datetime.utcnow + ) + + class Post(TimestampMixin, db.Model): # type: ignore[no-redef] + id = db.Column(db.Integer, primary_key=True) + title = db.Column(db.String, nullable=False) + + db.create_all() + db.session.add(Post(title="Admin Post")) + db.session.commit() + post = db.session.execute(db.select(Post)).scalar() + assert post is not None + assert post.created is not None + assert post.updated is not None + + @pytest.mark.usefixtures("app_ctx") def test_model_repr(db: SQLAlchemy) -> None: class User(db.Model): From fe9e4de6e9083e2238321c2896088c0ada502dcb Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Sat, 26 Aug 2023 06:24:11 -0700 Subject: [PATCH 26/27] Make suggested typing fix --- src/flask_sqlalchemy/extension.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 0ce065a5..34a85ab6 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -32,13 +32,13 @@ # Type accepted for model_class argument -_FSA_MC = t.TypeVar( - "_FSA_MC", +_FSA_MCT = t.TypeVar( + "_FSA_MCT", bound=t.Union[ - Model, + type[Model], sa_orm.DeclarativeMeta, - sa_orm.DeclarativeBase, - sa_orm.DeclarativeBaseNoMeta, + type[sa_orm.DeclarativeBase], + type[sa_orm.DeclarativeBaseNoMeta], ], ) @@ -49,7 +49,7 @@ class _FSAModel(Model): def _get_2x_declarative_bases( - model_class: t.Type[_FSA_MC], + model_class: _FSA_MCT, ) -> list[t.Type[t.Union[sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta]]]: return [ b @@ -168,7 +168,7 @@ def __init__( metadata: sa.MetaData | None = None, session_options: dict[str, t.Any] | None = None, query_class: type[Query] = Query, - model_class: t.Type[_FSA_MC] = Model, # type: ignore[assignment] + model_class: _FSA_MCT = Model, # type: ignore[assignment] engine_options: dict[str, t.Any] | None = None, add_models_to_shell: bool = True, disable_autonaming: bool = False, @@ -502,7 +502,7 @@ def __new__( def _make_declarative_base( self, - model_class: t.Type[_FSA_MC], + model_class: _FSA_MCT, disable_autonaming: bool = False, ) -> t.Type[_FSAModel]: """Create a SQLAlchemy declarative model class. The result is available as @@ -540,7 +540,7 @@ def _make_declarative_base( " Got: {}".format(model_class.__bases__) ) elif len(declarative_bases) == 1: - body = dict(model_class.__dict__) # type: ignore[arg-type] + body = dict(model_class.__dict__) body["__fsa__"] = self mixin_classes = [BindMixin, NameMixin, Model] if disable_autonaming: @@ -558,7 +558,7 @@ def _make_declarative_base( metadata=metadata, cls=model_class, name="Model", metaclass=metaclass ) else: - model = model_class + model = model_class # type: ignore[assignment] if None not in self.metadatas: # Use the model's metadata as the default metadata. From fdeec1d0d98669cc612e1f69d6875f9c1e4c6c45 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Sat, 26 Aug 2023 06:30:44 -0700 Subject: [PATCH 27/27] Adjusted to work in 3.8 --- src/flask_sqlalchemy/extension.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 34a85ab6..132ead56 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -35,10 +35,10 @@ _FSA_MCT = t.TypeVar( "_FSA_MCT", bound=t.Union[ - type[Model], + t.Type[Model], sa_orm.DeclarativeMeta, - type[sa_orm.DeclarativeBase], - type[sa_orm.DeclarativeBaseNoMeta], + t.Type[sa_orm.DeclarativeBase], + t.Type[sa_orm.DeclarativeBaseNoMeta], ], )