Skip to content

Commit c99ae11

Browse files
authored
Merge pull request #93 from roman-certn/main
Add support for GenericRelation field
2 parents e57725f + 84b9824 commit c99ae11

File tree

7 files changed

+205
-10
lines changed

7 files changed

+205
-10
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- Added support for Django's reverse generic relations (`GenericRelation` model field) ([#93](https://github.com/dabapps/django-readers/pull/93)).
12+
1013
### Changed
1114
- Add support for Django 5.0
1215

django_readers/qs.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor
12
from django.db.models import Prefetch, QuerySet
23
from django.db.models.constants import LOOKUP_SEP
34
from django.db.models.fields.related_descriptors import (
@@ -167,6 +168,37 @@ def prefetch_reverse_relationship(
167168
)
168169

169170

171+
def prefetch_reverse_generic_relationship(
172+
name,
173+
content_type_field_name,
174+
object_id_field_name,
175+
related_queryset,
176+
prepare_related_queryset=noop,
177+
to_attr=None,
178+
):
179+
"""
180+
Efficiently prefetch a reverse generic relationship: one where the field on the "parent"
181+
queryset is a `GenericRelation` field. We need to include this field in the query.
182+
"""
183+
return pipe(
184+
include_fields(name),
185+
prefetch_related(
186+
Prefetch(
187+
name,
188+
pipe(
189+
include_fields(
190+
"pk",
191+
content_type_field_name,
192+
object_id_field_name,
193+
),
194+
prepare_related_queryset,
195+
)(related_queryset),
196+
to_attr,
197+
)
198+
),
199+
)
200+
201+
170202
def prefetch_many_to_many_relationship(
171203
name, related_queryset, prepare_related_queryset=noop, to_attr=None
172204
):
@@ -246,5 +278,14 @@ def prepare(queryset):
246278
prepare_related_queryset,
247279
to_attr,
248280
)(queryset)
281+
if type(related_descriptor) is ReverseGenericManyToOneDescriptor:
282+
return prefetch_reverse_generic_relationship(
283+
name,
284+
related_descriptor.rel.field.content_type_field_name,
285+
related_descriptor.rel.field.object_id_field_name,
286+
related_descriptor.field.related_model.objects.all(),
287+
prepare_related_queryset,
288+
to_attr,
289+
)(queryset)
249290

250291
return prepare

django_readers/rest_framework.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from copy import deepcopy
2+
from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor
23
from django.core.exceptions import ImproperlyConfigured
34
from django.utils.functional import cached_property
45
from django_readers import specs
@@ -124,10 +125,25 @@ def _get_child_serializer_kwargs(self, rel_info):
124125
kwargs["allow_null"] = True
125126
return kwargs
126127

128+
def _get_rel_info(self, rel_name):
129+
descriptor = getattr(self.model, rel_name)
130+
# Special case for reverse generic relations (GenericRelation field)
131+
# as these don't appear in rest-framework's rel_info
132+
if isinstance(descriptor, ReverseGenericManyToOneDescriptor):
133+
return model_meta.RelationInfo(
134+
model_field=descriptor.field,
135+
related_model=descriptor.field.related_model,
136+
to_many=True,
137+
to_field=None,
138+
has_through_model=False,
139+
reverse=True,
140+
)
141+
return self.info.relations[rel_name]
142+
127143
def visit_dict_item_list(self, key, value):
128144
# This is a relationship, so we recurse and create
129145
# a nested serializer to represent it
130-
rel_info = self.info.relations[key]
146+
rel_info = self._get_rel_info(key)
131147
capfirst = self._lowercase_with_underscores_to_capitalized_words(key)
132148
child_serializer_class = serializer_class_for_spec(
133149
f"{self.name}{capfirst}",
@@ -143,7 +159,7 @@ def visit_dict_item_dict(self, key, value):
143159
# do the same as the previous case, but handled
144160
# slightly differently to set the `source` correctly
145161
relationship_name, relationship_spec = next(iter(value.items()))
146-
rel_info = self.info.relations[relationship_name]
162+
rel_info = self._get_rel_info(relationship_name)
147163
capfirst = self._lowercase_with_underscores_to_capitalized_words(key)
148164
child_serializer_class = serializer_class_for_spec(
149165
f"{self.name}{capfirst}",

tests/models.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
1+
from django.contrib.contenttypes.fields import GenericRelation
12
from django.db import models
23

34

5+
class LogEntry(models.Model):
6+
content_type = models.ForeignKey(
7+
to="contenttypes.ContentType",
8+
on_delete=models.CASCADE,
9+
related_name="+",
10+
)
11+
object_pk = models.CharField(max_length=255)
12+
event = models.CharField(max_length=100)
13+
14+
415
class Group(models.Model):
516
name = models.CharField(max_length=100)
617

@@ -15,6 +26,9 @@ class Widget(models.Model):
1526
value = models.PositiveIntegerField(default=0)
1627
other = models.CharField(max_length=100, null=True)
1728
owner = models.ForeignKey(Owner, null=True, on_delete=models.SET_NULL)
29+
logs = GenericRelation(
30+
LogEntry, content_type_field="content_type", object_id_field="object_pk"
31+
)
1832

1933

2034
class Thing(models.Model):

tests/test_qs.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from django.contrib.contenttypes.models import ContentType
12
from django.db import connection
23
from django.db.models import Count
34
from django.test import TestCase
45
from django.test.utils import CaptureQueriesContext
56
from django_readers import qs
6-
from tests.models import Category, Owner, Widget
7+
from tests.models import Category, LogEntry, Owner, Widget
78
from unittest import mock
89

910

@@ -188,6 +189,84 @@ def test_prefetch_reverse_relationship(self):
188189
with self.assertNumQueries(0):
189190
self.assertEqual(owners[0].widget_set.all()[0].name, "test widget")
190191

192+
def test_prefetch_reverse_generic_relationship(self):
193+
widget = Widget.objects.create(name="test widget")
194+
LogEntry.objects.create(
195+
content_type=ContentType.objects.get_for_model(widget),
196+
object_pk=widget.id,
197+
event="CREATED",
198+
)
199+
200+
prepare = qs.pipe(
201+
qs.include_fields("name"),
202+
qs.prefetch_reverse_generic_relationship(
203+
"logs",
204+
"content_type",
205+
"object_pk",
206+
LogEntry.objects.all(),
207+
qs.include_fields("event"),
208+
),
209+
)
210+
211+
with CaptureQueriesContext(connection) as capture:
212+
widgets = list(prepare(Widget.objects.all()))
213+
214+
self.assertEqual(len(capture.captured_queries), 2)
215+
216+
self.assertEqual(
217+
capture.captured_queries[0]["sql"],
218+
"SELECT "
219+
'"tests_widget"."id", '
220+
'"tests_widget"."name" '
221+
"FROM "
222+
'"tests_widget"',
223+
)
224+
225+
content_type_id = ContentType.objects.get_for_model(Widget).pk
226+
227+
self.assertEqual(
228+
capture.captured_queries[1]["sql"],
229+
"SELECT "
230+
'"tests_logentry"."id", '
231+
'"tests_logentry"."content_type_id", '
232+
'"tests_logentry"."object_pk", '
233+
'"tests_logentry"."event" '
234+
"FROM "
235+
'"tests_logentry" '
236+
"WHERE "
237+
f'("tests_logentry"."content_type_id" = {content_type_id} AND '
238+
'"tests_logentry"."object_pk" IN '
239+
"('1'))",
240+
)
241+
242+
with self.assertNumQueries(0):
243+
self.assertEqual(widgets[0].logs.all()[0].event, "CREATED")
244+
245+
def test_prefetch_reverse_generic_relationship_with_to_attr(self):
246+
widget = Widget.objects.create(name="test widget")
247+
LogEntry.objects.create(
248+
content_type=ContentType.objects.get_for_model(widget),
249+
object_pk=widget.id,
250+
event="CREATED",
251+
)
252+
253+
prepare = qs.pipe(
254+
qs.include_fields("name"),
255+
qs.prefetch_reverse_generic_relationship(
256+
"logs",
257+
"content_type",
258+
"object_pk",
259+
LogEntry.objects.all(),
260+
qs.include_fields("event"),
261+
to_attr="history",
262+
),
263+
)
264+
265+
widgets = list(prepare(Widget.objects.all()))
266+
267+
with self.assertNumQueries(0):
268+
self.assertEqual(widgets[0].history[0].event, "CREATED")
269+
191270
def test_prefetch_reverse_relationship_only_loads_pk_and_related_name_by_default(
192271
self,
193272
):
@@ -358,6 +437,12 @@ def test_auto_prefetch_relationship(self):
358437
qs.auto_prefetch_relationship("category_set")(Widget.objects.all())
359438
mock_fn.assert_called_once()
360439

440+
with mock.patch(
441+
"django_readers.qs.prefetch_reverse_generic_relationship"
442+
) as mock_fn:
443+
qs.auto_prefetch_relationship("logs")(Widget.objects.all())
444+
mock_fn.assert_called_once()
445+
361446
def test_annotate_only_includes_fk_by_default(self):
362447
owner = Owner.objects.create(name="test owner")
363448
Widget.objects.create(name="test 1", owner=owner)

tests/test_rest_framework.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from django.contrib.contenttypes.models import ContentType
12
from django.core.exceptions import ImproperlyConfigured
23
from django.test import TestCase
34
from django_readers import pairs, qs
@@ -10,7 +11,7 @@
1011
from rest_framework import serializers
1112
from rest_framework.generics import ListAPIView, RetrieveAPIView
1213
from rest_framework.test import APIRequestFactory
13-
from tests.models import Category, Group, Owner, Widget
14+
from tests.models import Category, Group, LogEntry, Owner, Widget
1415
from textwrap import dedent
1516

1617

@@ -28,6 +29,7 @@ class WidgetListView(SpecMixin, ListAPIView):
2829
},
2930
]
3031
},
32+
{"logs": ["event"]},
3133
]
3234

3335

@@ -53,17 +55,32 @@ class CategoryDetailView(SpecMixin, RetrieveAPIView):
5355

5456
class RESTFrameworkTestCase(TestCase):
5557
def test_list(self):
56-
Widget.objects.create(
58+
widget = Widget.objects.create(
5759
name="test widget",
5860
owner=Owner.objects.create(
5961
name="test owner", group=Group.objects.create(name="test group")
6062
),
6163
)
64+
LogEntry.objects.create(
65+
content_type=ContentType.objects.get_for_model(widget),
66+
object_pk=widget.id,
67+
event="CREATED",
68+
)
69+
LogEntry.objects.create(
70+
content_type=ContentType.objects.get_for_model(widget),
71+
object_pk=widget.id,
72+
event="UPDATED",
73+
)
74+
LogEntry.objects.create(
75+
content_type=ContentType.objects.get_for_model(widget),
76+
object_pk=widget.id,
77+
event="DELETED",
78+
)
6279

6380
request = APIRequestFactory().get("/")
6481
view = WidgetListView.as_view()
6582

66-
with self.assertNumQueries(3):
83+
with self.assertNumQueries(4):
6784
response = view(request)
6885

6986
self.assertEqual(
@@ -77,6 +94,11 @@ def test_list(self):
7794
"name": "test group",
7895
},
7996
},
97+
"logs": [
98+
{"event": "CREATED"},
99+
{"event": "UPDATED"},
100+
{"event": "DELETED"},
101+
],
80102
}
81103
],
82104
)
@@ -180,12 +202,16 @@ def test_all_relationship_types(self):
180202
},
181203
]
182204
},
205+
{
206+
"logs": [
207+
"event",
208+
]
209+
},
183210
]
184211
},
185212
]
186213

187214
cls = serializer_class_for_spec("Owner", Owner, spec)
188-
189215
expected = dedent(
190216
"""\
191217
OwnerSerializer():
@@ -199,7 +225,9 @@ def test_all_relationship_types(self):
199225
thing = OwnerWidgetSetThingSerializer(read_only=True):
200226
name = CharField(max_length=100, read_only=True)
201227
related_widget = OwnerWidgetSetThingRelatedWidgetSerializer(allow_null=True, read_only=True, source='widget'):
202-
name = CharField(allow_null=True, max_length=100, read_only=True, required=False)"""
228+
name = CharField(allow_null=True, max_length=100, read_only=True, required=False)
229+
logs = OwnerWidgetSetLogsSerializer(allow_null=True, many=True, read_only=True):
230+
event = CharField(max_length=100, read_only=True)"""
203231
)
204232
self.assertEqual(repr(cls()), expected)
205233

tests/test_specs.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from django.contrib.contenttypes.models import ContentType
12
from django.test import TestCase
23
from django_readers import specs
3-
from tests.models import Category, Group, Owner, Thing, Widget
4+
from tests.models import Category, Group, LogEntry, Owner, Thing, Widget
45

56

67
class SpecTestCase(TestCase):
@@ -28,20 +29,26 @@ def test_relationships(self):
2829
category = Category.objects.create(name="test category")
2930
category.widget_set.add(widget)
3031
Thing.objects.create(name="test thing", widget=widget)
32+
LogEntry.objects.create(
33+
content_type=ContentType.objects.get_for_model(widget),
34+
object_pk=widget.id,
35+
event="CREATED",
36+
)
3137

3238
prepare, project = specs.process(
3339
[
3440
"name",
3541
{"owner": ["name", {"widget_set": ["name"]}]},
3642
{"category_set": ["name", {"widget_set": ["name"]}]},
3743
{"thing": ["name", {"widget": ["name"]}]},
44+
{"logs": ["event"]},
3845
]
3946
)
4047

4148
with self.assertNumQueries(0):
4249
queryset = prepare(Widget.objects.all())
4350

44-
with self.assertNumQueries(7):
51+
with self.assertNumQueries(8):
4552
instance = queryset.first()
4653

4754
with self.assertNumQueries(0):
@@ -59,6 +66,7 @@ def test_relationships(self):
5966
{"name": "test category", "widget_set": [{"name": "test widget"}]},
6067
],
6168
"thing": {"name": "test thing", "widget": {"name": "test widget"}},
69+
"logs": [{"event": "CREATED"}],
6270
},
6371
)
6472

0 commit comments

Comments
 (0)