Skip to content

Commit

Permalink
add prefetch support
Browse files Browse the repository at this point in the history
  • Loading branch information
knifecake committed Dec 22, 2024
1 parent ee7cc7b commit 52e6cda
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 27 deletions.
6 changes: 6 additions & 0 deletions anchor/models/attachment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from anchor.models.base import BaseModel


class AttachmentManager(models.Manager):
pass


class Attachment(BaseModel):
class Meta:
constraints = (
Expand All @@ -22,6 +26,8 @@ class Meta:
),
)

objects = AttachmentManager()

blob = models.ForeignKey(
"anchor.Blob",
on_delete=models.PROTECT,
Expand Down
3 changes: 3 additions & 0 deletions anchor/models/fields/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .single_attachment import SingleAttachmentField

__all__ = ["SingleAttachmentField"]
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
from typing import Callable

from django.contrib.contenttypes.fields import GenericRelation
from django.contrib.contenttypes.fields import GenericRel, GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db.models import Model
from django.db.models.fields.related_descriptors import ReverseOneToOneDescriptor

from anchor.models import Attachment, Blob


class ReverseSingleAttachmentDescriptor:
class ReverseSingleAttachmentDescriptor(ReverseOneToOneDescriptor):
def __init__(
self,
related: GenericRel,
name: str,
upload_to: str | Callable[[Blob], str] = None,
backend: str = None,
):
self.related = related
self.name = name
self.upload_to = upload_to
self.backend = backend

def __get__(self, instance, cls=None):
if instance is None:
return self

return Attachment.objects.filter(
object_id=instance.id,
content_type=ContentType.objects.get_for_model(instance),
name=self.name,
order=0,
).get()

def __set__(self, instance, value):
if isinstance(value, Attachment):
value.object_id = instance.id
Expand Down Expand Up @@ -58,6 +50,50 @@ def __set__(self, instance, value):
defaults={"blob": blob},
)

def get_queryset(self, **hints):
return (
Attachment._base_manager.db_manager(hints=hints)
.select_related("blob")
.all()
)

def get_prefetch_querysets(self, instances, querysets=None):
if querysets and len(querysets) != 1:
raise ValueError(
"querysets argument of get_prefetch_querysets() should have a length "
"of 1."
)
queryset = querysets[0] if querysets else self.get_queryset()
queryset._add_hints(instance=instances[0])
queryset = queryset.filter(
object_id__in=(instance.id for instance in instances),
content_type=ContentType.objects.get_for_model(instances[0]),
name=self.name,
order=0,
)
rel_obj_attr = self.related.field.get_local_related_value

def instance_attr(i):
return tuple(
str(x) for x in self.related.field.get_foreign_related_value(i)
)

instances_dict = {instance_attr(inst): inst for inst in instances}

# Since we're going to assign directly in the cache,
# we must manage the reverse relation cache manually.
for rel_obj in queryset:
instance = instances_dict[rel_obj_attr(rel_obj)]
self.related.field.set_cached_value(rel_obj, instance)
return (
queryset,
rel_obj_attr,
instance_attr,
True,
self.related.cache_name,
False,
)


class SingleAttachmentField(GenericRelation):
def __init__(
Expand All @@ -80,16 +116,18 @@ def __init__(
kwargs["from_fields"] = []
kwargs["serialize"] = False

self.rel = self.rel_class(
self,
to="anchor.Attachment",
related_name="+",
related_query_name="+",
limit_choices_to=None,
)

# Bypass the GenericRelation constructor to be able to set editable=True
super(GenericRelation, self).__init__(
to="anchor.Attachment",
rel=self.rel_class(
self,
to="anchor.Attachment",
related_name="+",
related_query_name="+",
limit_choices_to=None,
),
rel=self.rel,
**kwargs,
)

Expand All @@ -99,7 +137,10 @@ def contribute_to_class(self, cls: type[Model], name: str, **kwargs) -> None:
cls,
name,
ReverseSingleAttachmentDescriptor(
name=name, upload_to=self.upload_to, backend=self.backend
related=self.rel,
name=name,
upload_to=self.upload_to,
backend=self.backend,
),
)

Expand All @@ -110,3 +151,11 @@ def formfield(self, **kwargs):
defaults.update(kwargs)

return FileField(**defaults)

def get_forward_related_filter(self, obj):
return {
"object_id": obj.id,
"content_type": ContentType.objects.get_for_model(obj),
"name": self.name,
"order": 0,
}
18 changes: 13 additions & 5 deletions demo/movies/tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from unittest import skip

from django.core.files.base import ContentFile
from django.test import TestCase

Expand All @@ -20,8 +18,18 @@ def setUpTestData(cls):
cls.movie.cover = ContentFile(b"test", name="test.txt")
cls.movie.save()

@skip("Not implemented yet")
def test_prefetch_cover(self):
def test_prefetch_cover_with_single_object(self):
movie = Movie.objects.prefetch_related("cover").get(id=self.movie.id)
with self.assertNumQueries(0):
self.assertEqual(movie.cover.byte_size, 4)
self.assertEqual(movie.cover.filename, "test.txt")

def test_prefetch_cover_with_multiple_objects(self):
movies = [Movie.objects.create(title="Test Movie %d" % i) for i in range(10)]
for movie in movies:
movie.cover = ContentFile(b"test", name="test.txt")
movie.save()

movies = list(Movie.objects.prefetch_related("cover").all())
with self.assertNumQueries(0):
for movie in movies:
self.assertEqual(movie.cover.filename, "test.txt")
3 changes: 2 additions & 1 deletion demo/movies/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@


class MovieDetailView(DetailView):
model = Movie
def get_queryset(self):
return Movie.objects.prefetch_related("cover")


class MovieUpdateView(UpdateView):
Expand Down

0 comments on commit 52e6cda

Please sign in to comment.