Skip to content

Commit

Permalink
lint: use type narrowing for orm.object_session() (#17582)
Browse files Browse the repository at this point in the history
Since `object_session(...)` can return `None`, we need to help mypy have
confidence that the session is indeed there.

One way to narrow the type is to assert that it's not `None`, and
therefore should not alert about `union-attr` problems.

Includes a couple of renames, and variable extraction where relevant.

Signed-off-by: Mike Fiedler <miketheman@gmail.com>
  • Loading branch information
miketheman authored Feb 11, 2025
1 parent 1a5b4fe commit 74b7a87
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 40 deletions.
31 changes: 31 additions & 0 deletions tests/unit/utils/db/test_orm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from sqlalchemy.orm import object_session

from warehouse.db import Model
from warehouse.utils.db.orm import NoSessionError, orm_session_from_obj


def test_orm_session_from_obj_raises_with_no_session():

class FakeObject(Model):
__tablename__ = "fake_object"

obj = FakeObject()
# Confirm that the object does not have a session with the built-in
assert object_session(obj) is None

with pytest.raises(NoSessionError):
orm_session_from_obj(obj)
3 changes: 2 additions & 1 deletion warehouse/accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from warehouse.observations.models import HasObservations, HasObservers, ObservationKind
from warehouse.sitemap.models import SitemapMixin
from warehouse.utils.attrs import make_repr
from warehouse.utils.db import orm_session_from_obj
from warehouse.utils.db.types import TZDateTime, bool_false, datetime_now

if TYPE_CHECKING:
Expand Down Expand Up @@ -236,7 +237,7 @@ def has_primary_verified_email(self):

@property
def recent_events(self):
session = orm.object_session(self)
session = orm_session_from_obj(self)
last_ninety = datetime.datetime.now() - datetime.timedelta(days=90)
return (
session.query(User.Event)
Expand Down
5 changes: 2 additions & 3 deletions warehouse/cache/origin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

from itertools import chain

from sqlalchemy.orm.session import Session

from warehouse import db
from warehouse.cache.origin.derivers import html_cache_deriver
from warehouse.cache.origin.interfaces import IOriginCache
from warehouse.utils.db import orm_session_from_obj


@db.listens_for(db.Session, "after_flush")
Expand Down Expand Up @@ -139,7 +138,7 @@ def register_origin_cache_keys(config, klass, cache_keys=None, purge_keys=None):

def receive_set(attribute, config, target):
cache_keys = config.registry["cache_keys"]
session = Session.object_session(target)
session = orm_session_from_obj(target)
purges = session.info.setdefault("warehouse.cache.origin.purges", set())
key_maker = cache_keys[attribute]
keys = key_maker(target).purge
Expand Down
6 changes: 3 additions & 3 deletions warehouse/email/ses/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from sqlalchemy.dialects.postgresql import JSONB, UUID as PG_UUID
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm.session import object_session

from warehouse import db
from warehouse.accounts.models import Email as EmailAddress, UnverifyReasons
from warehouse.utils.db import orm_session_from_obj
from warehouse.utils.db.types import bool_false, datetime_now

MAX_TRANSIENT_BOUNCES = 5
Expand Down Expand Up @@ -217,9 +217,9 @@ def _get_email(self):
if self._email_message.missing:
return

db = object_session(self._email_message)
session = orm_session_from_obj(self._email_message)
email = (
db.query(EmailAddress)
session.query(EmailAddress)
.filter(EmailAddress.email == self._email_message.to)
.first()
)
Expand Down
4 changes: 2 additions & 2 deletions warehouse/legacy/api/xmlrpc/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from pyramid.exceptions import ConfigurationError
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.orm.session import Session
from urllib3.util import parse_url

from warehouse import db
Expand All @@ -23,6 +22,7 @@
from warehouse.legacy.api.xmlrpc.cache.fncache import RedisLru
from warehouse.legacy.api.xmlrpc.cache.interfaces import IXMLRPCCache
from warehouse.legacy.api.xmlrpc.cache.services import NullXMLRPCCache, RedisXMLRPCCache
from warehouse.utils.db import orm_session_from_obj

__all__ = ["RedisLru"]

Expand All @@ -32,7 +32,7 @@

def receive_set(attribute, config, target):
cache_keys = config.registry["cache_keys"]
session = Session.object_session(target)
session = orm_session_from_obj(target)
purges = session.info.setdefault("warehouse.legacy.api.xmlrpc.cache.purges", set())
key_maker = cache_keys[attribute]
keys = key_maker(target).purge
Expand Down
14 changes: 5 additions & 9 deletions warehouse/organizations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from warehouse.authnz import Permissions
from warehouse.events.models import HasEvents
from warehouse.utils.attrs import make_repr
from warehouse.utils.db import orm_session_from_obj
from warehouse.utils.db.types import TZDateTime, bool_false, datetime_now

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -332,22 +333,17 @@ class Organization(OrganizationMixin, HasEvents, db.Model):
@property
def owners(self):
"""Return all users who are owners of the organization."""
session = orm_session_from_obj(self)
owner_roles = (
orm.object_session(self)
.query(User.id)
session.query(User.id)
.join(OrganizationRole.user)
.filter(
OrganizationRole.role_name == OrganizationRoleType.Owner,
OrganizationRole.organization == self,
)
.subquery()
)
return (
orm.object_session(self)
.query(User)
.join(owner_roles, User.id == owner_roles.c.id)
.all()
)
return session.query(User).join(owner_roles, User.id == owner_roles.c.id).all()

def record_event(self, *, tag, request: Request = None, additional=None):
"""Record organization name in events in case organization is ever deleted."""
Expand All @@ -358,7 +354,7 @@ def record_event(self, *, tag, request: Request = None, additional=None):
)

def __acl__(self):
session = orm.object_session(self)
session = orm_session_from_obj(self)

acls = [
(
Expand Down
37 changes: 16 additions & 21 deletions warehouse/packaging/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from warehouse.sitemap.models import SitemapMixin
from warehouse.utils import dotted_navigator, wheel
from warehouse.utils.attrs import make_repr
from warehouse.utils.db import orm_session_from_obj
from warehouse.utils.db.types import bool_false, bool_true, datetime_now

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -257,7 +258,7 @@ class Project(SitemapMixin, HasEvents, HasObservations, db.Model):
)

def __getitem__(self, version):
session = orm.object_session(self)
session = orm_session_from_obj(self)
canonical_version = packaging.utils.canonicalize_version(version)

try:
Expand Down Expand Up @@ -288,7 +289,7 @@ def __getitem__(self, version):
raise KeyError from None

def __acl__(self):
session = orm.object_session(self)
session = orm_session_from_obj(self)
acls = [
# TODO: Similar to `warehouse.accounts.models.User.__acl__`, we express the
# permissions here in terms of the permissions that the user has on
Expand Down Expand Up @@ -417,42 +418,36 @@ def documentation_url(self):
@property
def owners(self):
"""Return all users who are owners of the project."""
session = orm_session_from_obj(self)
owner_roles = (
orm.object_session(self)
.query(User.id)
session.query(User.id)
.join(Role.user)
.filter(Role.role_name == "Owner", Role.project == self)
.subquery()
)
return (
orm.object_session(self)
.query(User)
.join(owner_roles, User.id == owner_roles.c.id)
.all()
)
return session.query(User).join(owner_roles, User.id == owner_roles.c.id).all()

@property
def maintainers(self):
"""Return all users who are maintainers of the project."""
session = orm_session_from_obj(self)
maintainer_roles = (
orm.object_session(self)
.query(User.id)
session.query(User.id)
.join(Role.user)
.filter(Role.role_name == "Maintainer", Role.project == self)
.subquery()
)
return (
orm.object_session(self)
.query(User)
session.query(User)
.join(maintainer_roles, User.id == maintainer_roles.c.id)
.all()
)

@property
def all_versions(self):
session = orm_session_from_obj(self)
return (
orm.object_session(self)
.query(
session.query(
Release.version,
Release.created,
Release.is_prerelease,
Expand All @@ -466,9 +461,9 @@ def all_versions(self):

@property
def latest_version(self):
session = orm_session_from_obj(self)
return (
orm.object_session(self)
.query(Release.version, Release.created, Release.is_prerelease)
session.query(Release.version, Release.created, Release.is_prerelease)
.filter(Release.project == self, Release.yanked.is_(False))
.order_by(Release.is_prerelease.nullslast(), Release._pypi_ordering.desc())
.first()
Expand All @@ -477,7 +472,7 @@ def latest_version(self):
@property
def active_releases(self):
return (
orm.object_session(self)
orm_session_from_obj(self)
.query(Release)
.filter(Release.project == self, Release.yanked.is_(False))
.order_by(Release._pypi_ordering.desc())
Expand All @@ -487,7 +482,7 @@ def active_releases(self):
@property
def yanked_releases(self):
return (
orm.object_session(self)
orm_session_from_obj(self)
.query(Release)
.filter(Release.project == self, Release.yanked.is_(True))
.order_by(Release._pypi_ordering.desc())
Expand Down Expand Up @@ -747,7 +742,7 @@ def __table_args__(cls): # noqa
uploaded_via: Mapped[str | None]

def __getitem__(self, filename: str) -> File:
session: orm.Session = orm.object_session(self) # type: ignore[assignment]
session = orm_session_from_obj(self)

try:
return (
Expand Down
3 changes: 2 additions & 1 deletion warehouse/utils/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from warehouse.utils.db.orm import orm_session_from_obj
from warehouse.utils.db.query_printer import print_query

__all__ = ["print_query"]
__all__ = ["orm_session_from_obj", "print_query"]
33 changes: 33 additions & 0 deletions warehouse/utils/db/orm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ORM utilities."""

from sqlalchemy.orm import Session, object_session


class NoSessionError(Exception):
"""Raised when there is no active SQLAlchemy session"""


def orm_session_from_obj(obj) -> Session:
"""
Returns the session from the ORM object.
Adds guard, but it should never happen.
The guard helps with type hinting, as the object_session function
returns Optional[Session] type.
"""
session = object_session(obj)
if not session:
raise NoSessionError("Object does not have a session")
return session

0 comments on commit 74b7a87

Please sign in to comment.