Skip to content

Commit

Permalink
feat: Add support for Q objects in FakeQuerySet filter (#186)
Browse files Browse the repository at this point in the history
  • Loading branch information
KIRA009 authored and gasman committed Feb 22, 2024
1 parent 67419c6 commit 8bf0faf
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Changelog

6.3 (xx.xx.xxxx)
~~~~~~~~~~~~~~~~
* Support filtering with Q objects (Shohan Dutta Roy)
* Fix: Correctly handle filtering on fields on related models when those fields have names that match a lookup type (Andy Babic)
* Fix: Correctly handle null foreign keys when traversing related fields (Andy Babic)

Expand Down
46 changes: 38 additions & 8 deletions modelcluster/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re

from django.core.exceptions import FieldDoesNotExist
from django.db.models import Model, prefetch_related_objects
from django.db.models import Model, Q, prefetch_related_objects

from modelcluster.utils import NullRelationshipValueEncountered, extract_field_value, get_model_field, sort_by_fields

Expand Down Expand Up @@ -537,11 +537,41 @@ def get_clone(self, results = None):
new.tuple_fields = self.tuple_fields
new.iterable_class = self.iterable_class
return new

def resolve_q_object(self, q_object):
connector = q_object.connector
filters = []

def _get_filters(self, **kwargs):
def test(filters):
def test_inner(obj):
result = False
if connector == Q.AND:
result = all([test(obj) for test in filters])
elif connector == Q.OR:
result = any([test(obj) for test in filters])
else:
result = sum([test(obj) for test in filters]) == 1
if q_object.negated:
return not result
return result
return test_inner

for child in q_object.children:
if isinstance(child, Q):
filters.append(self.resolve_q_object(child))
else:
key_clauses, val = child
filters.append(_build_test_function_from_filter(self.model, key_clauses.split('__'), val))

return test(filters)

def _get_filters(self, *args, **kwargs):
# a list of test functions; objects must pass all tests to be included
# in the filtered list
filters = []

for q_object in args:
filters.append(self.resolve_q_object(q_object))

for key, val in kwargs.items():
filters.append(
Expand All @@ -550,26 +580,26 @@ def _get_filters(self, **kwargs):

return filters

def filter(self, **kwargs):
filters = self._get_filters(**kwargs)
def filter(self, *args, **kwargs):
filters = self._get_filters(*args, **kwargs)

clone = self.get_clone(results=[
obj for obj in self.results
if all([test(obj) for test in filters])
])
return clone

def exclude(self, **kwargs):
filters = self._get_filters(**kwargs)
def exclude(self, *args, **kwargs):
filters = self._get_filters(*args, **kwargs)

clone = self.get_clone(results=[
obj for obj in self.results
if not all([test(obj) for test in filters])
])
return clone

def get(self, **kwargs):
clone = self.filter(**kwargs)
def get(self, *args, **kwargs):
clone = self.filter(*args, **kwargs)
result_count = clone.count()

if result_count == 0:
Expand Down
32 changes: 31 additions & 1 deletion tests/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from django.test import TestCase
from django.db import IntegrityError
from django.db.models import Prefetch
from django.db.models import Prefetch, Q

from modelcluster.models import get_all_child_relations
from modelcluster.queryset import FakeQuerySet
Expand Down Expand Up @@ -705,6 +705,36 @@ def test_filtering_via_reverse_foreignkey(self):
()
)

def test_filtering_via_models_Q_objects(self):
band = Band(
name="The Beatles",
members=[
BandMember(name="John Lennon", favourite_restaurant=self.strawberry_fields),
BandMember(name="Ringo Starr", favourite_restaurant=self.the_yellow_submarine),
],
)

self.assertEqual(
tuple(band.members.filter(Q(name="John Lennon") | Q(favourite_restaurant__name="The Yellow Submarine"))),
(band.members.get(name="John Lennon"), band.members.get(name="Ringo Starr"))
)
self.assertEqual(
tuple(band.members.filter(Q(name="John Lennon") & Q(favourite_restaurant__name="Strawberry Fields"))),
(band.members.get(name="John Lennon"),)
)
self.assertEqual(
tuple(band.members.filter(Q(name="John Lennon") & ~Q(favourite_restaurant__name="The Yellow Submarine"))),
(band.members.get(name="John Lennon"),)
)
self.assertEqual(
tuple(band.members.filter(Q(name="John Lennon") & ~Q(favourite_restaurant__name="Strawberry Fields"))),
()
)
self.assertEqual(
tuple(band.members.filter(Q(name="John Lennon") ^ Q(favourite_restaurant__name="The Yellow Submarine"))),
(band.members.get(name="John Lennon"), band.members.get(name="Ringo Starr"))
)

def test_ordering_accross_foreignkeys(self):
band = Band(
name="The Beatles",
Expand Down

0 comments on commit 8bf0faf

Please sign in to comment.