Skip to content

Commit

Permalink
Handle null FK values
Browse files Browse the repository at this point in the history
  • Loading branch information
ababic committed Feb 4, 2024
1 parent 784232c commit dd88603
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
13 changes: 10 additions & 3 deletions modelcluster/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from django.core.exceptions import FieldDoesNotExist
from django.db.models import Model, prefetch_related_objects

from modelcluster.utils import extract_field_value, get_model_field, sort_by_fields
from modelcluster.utils import NullRelationshipValueEncountered, extract_field_value, get_model_field, sort_by_fields


# Constructor for test functions that determine whether an object passes some boolean condition
Expand Down Expand Up @@ -369,6 +369,13 @@ def _build_test_function_from_filter(model, key_clauses, val):
return constructor(model, attribute_name, val)


def _run_test(test, obj):
try:
return test(obj)
except NullRelationshipValueEncountered:
return False


class FakeQuerySetIterable:
def __init__(self, queryset):
self.queryset = queryset
Expand Down Expand Up @@ -438,7 +445,7 @@ def filter(self, **kwargs):

clone = self.get_clone(results=[
obj for obj in self.results
if all([test(obj) for test in filters])
if all([_run_test(test, obj) for test in filters])
])
return clone

Expand All @@ -447,7 +454,7 @@ def exclude(self, **kwargs):

clone = self.get_clone(results=[
obj for obj in self.results
if not all([test(obj) for test in filters])
if not all([_run_test(test, obj) for test in filters])
])
return clone

Expand Down
28 changes: 23 additions & 5 deletions modelcluster/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import lru_cache
from django.core.exceptions import FieldDoesNotExist
from django.db.models import ManyToManyField, ManyToManyRel
from django.db.models import ManyToManyField, ManyToManyRel, Model

REL_DELIMETER = "__"

Expand All @@ -9,6 +9,10 @@ class ManyToManyTraversalError(ValueError):
pass


class NullRelationshipValueEncountered(Exception):
pass


class TraversedRelationship:
__slots__ = ['from_model', 'field']

Expand Down Expand Up @@ -100,17 +104,31 @@ def extract_field_value(obj, key, pk_only=False, suppress_fielddoesnotexist=Fals
to get ``None`` values instead.
"""
source = obj
for attr in key.split(REL_DELIMETER):
if hasattr(source, attr):
value = getattr(source, attr)
latest_obj = obj
segments = key.split(REL_DELIMETER)
for i, segment in enumerate(segments, start=1):
if hasattr(source, segment):
value = getattr(source, segment)
if isinstance(value, Model):
latest_obj = value
if value is None and i < len(segments):
raise NullRelationshipValueEncountered(
"'{key}' cannot be reached for {obj} because a None value "
"was encountered at {model_class}.{field_name}".format(
key=key,
obj=str(obj),
model_class=latest_obj._meta.label,
field_name=segment
)
)
source = value
continue
elif suppress_fielddoesnotexist:
return None
else:
raise FieldDoesNotExist(
"'{name}' is not a valid field name for {model}".format(
name=attr, model=type(source)
name=segment, model=type(source)
)
)
if pk_only and hasattr(value, 'pk'):
Expand Down

0 comments on commit dd88603

Please sign in to comment.