Skip to content

Commit 8f2dfc5

Browse files
committed
parametrize component registry identity #1288
1 parent 2b4d5ab commit 8f2dfc5

File tree

8 files changed

+101
-14
lines changed

8 files changed

+101
-14
lines changed

drf_spectacular/contrib/pydantic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def get_name(self, auto_schema, direction):
2323
# of the entry model, we simply use the class name as string for object. This hack may
2424
# create false positive warnings, so turn it off. However, this may suppress correct
2525
# warnings involving the entry class.
26+
# TODO suppression may be migrated to new ComponentIdentity system
2627
set_override(self.target, 'suppress_collision_warning', True)
2728
return self.target.__name__
2829

drf_spectacular/contrib/rest_framework_dataclasses.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from typing import Any
2+
13
from drf_spectacular.drainage import get_override, has_override
24
from drf_spectacular.extensions import OpenApiSerializerExtension
3-
from drf_spectacular.plumbing import get_doc
5+
from drf_spectacular.plumbing import ComponentIdentity, get_doc
46
from drf_spectacular.utils import Direction
57

68

@@ -18,6 +20,9 @@ def get_name(self):
1820
return get_override(self.target.dataclass_definition.dataclass_type, 'component_name')
1921
return self.target.dataclass_definition.dataclass_type.__name__
2022

23+
def get_identity(self, auto_schema, direction: Direction) -> Any:
24+
return ComponentIdentity(self.target.dataclass_definition.dataclass_type)
25+
2126
def strip_library_doc(self, schema):
2227
"""Strip the DataclassSerializer library documentation from the schema."""
2328
from rest_framework_dataclasses.serializers import DataclassSerializer

drf_spectacular/contrib/rest_polymorphic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from drf_spectacular.drainage import warn
22
from drf_spectacular.extensions import OpenApiSerializerExtension
33
from drf_spectacular.plumbing import (
4-
ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer,
4+
ComponentIdentity, ResolvedComponent, build_basic_type, build_object_type,
5+
is_patched_serializer,
56
)
67
from drf_spectacular.settings import spectacular_settings
78
from drf_spectacular.types import OpenApiTypes
@@ -25,7 +26,7 @@ def map_serializer(self, auto_schema, direction):
2526
component = ResolvedComponent(
2627
name=auto_schema._get_serializer_name(sub_serializer, direction),
2728
type=ResolvedComponent.SCHEMA,
28-
object='virtual'
29+
object=ComponentIdentity('virtual')
2930
)
3031
typed_component = self.build_typed_component(
3132
auto_schema=auto_schema,

drf_spectacular/extensions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[
6868
""" return str for overriding default name extraction """
6969
return None
7070

71+
def get_identity(self, auto_schema: 'AutoSchema', direction: Direction) -> Any:
72+
""" return anything to compare instances of target. Target will be used by default. """
73+
return None
74+
7175
def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
7276
""" override for customized serializer mapping """
7377
return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True)

drf_spectacular/openapi.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,12 +1478,13 @@ def _get_response_for_code(self, serializer, status_code, media_types=None, dire
14781478
and is_serializer(serializer)
14791479
and (not is_list_serializer(serializer) or is_serializer(serializer.child))
14801480
):
1481-
paginated_name = self.get_paginated_name(self._get_serializer_name(serializer, "response"))
14821481
component = ResolvedComponent(
1483-
name=paginated_name,
1482+
name=self.get_paginated_name(self._get_serializer_name(serializer, 'response')),
14841483
type=ResolvedComponent.SCHEMA,
14851484
schema=paginator.get_paginated_response_schema(schema),
1486-
object=serializer.child if is_list_serializer(serializer) else serializer,
1485+
object=self.get_serializer_identity(
1486+
serializer.child if is_list_serializer(serializer) else serializer, 'response'
1487+
)
14871488
)
14881489
self.registry.register_on_missing(component)
14891490
schema = component.ref
@@ -1556,7 +1557,17 @@ def _get_response_headers_for_code(self, status_code, direction='response') -> _
15561557

15571558
return result
15581559

1560+
def get_serializer_identity(self, serializer, direction: Direction) -> Any:
1561+
serializer_extension = OpenApiSerializerExtension.get_match(serializer)
1562+
if serializer_extension:
1563+
identity = serializer_extension.get_identity(self, direction)
1564+
if identity is not None:
1565+
return identity
1566+
1567+
return serializer
1568+
15591569
def get_serializer_name(self, serializer: serializers.Serializer, direction: Direction) -> str:
1570+
""" override this for custom behaviour """
15601571
return serializer.__class__.__name__
15611572

15621573
def _get_serializer_name(self, serializer, direction, bypass_extensions=False) -> str:
@@ -1612,7 +1623,7 @@ def resolve_serializer(
16121623
component = ResolvedComponent(
16131624
name=self._get_serializer_name(serializer, direction, bypass_extensions),
16141625
type=ResolvedComponent.SCHEMA,
1615-
object=serializer,
1626+
object=self.get_serializer_identity(serializer, direction),
16161627
)
16171628
if component in self.registry:
16181629
return self.registry[component] # return component with schema

drf_spectacular/plumbing.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,17 @@ def ref(self) -> _SchemaType:
723723
return {'$ref': f'#/components/{self.type}/{self.name}'}
724724

725725

726+
class ComponentIdentity:
727+
""" A container class to make object/component comparison explicit """
728+
def __init__(self, obj):
729+
self.obj = obj
730+
731+
def __eq__(self, other):
732+
if isinstance(other, ComponentIdentity):
733+
return self.obj == other.obj
734+
return self.obj == other
735+
736+
726737
class ComponentRegistry:
727738
def __init__(self) -> None:
728739
self._components: Dict[Tuple[str, str], ResolvedComponent] = {}
@@ -746,17 +757,25 @@ def __contains__(self, component):
746757

747758
query_obj = component.object
748759
registry_obj = self._components[component.key].object
749-
query_class = query_obj if inspect.isclass(query_obj) else query_obj.__class__
750-
registry_class = query_obj if inspect.isclass(registry_obj) else registry_obj.__class__
760+
761+
if isinstance(query_obj, ComponentIdentity) or inspect.isclass(query_obj):
762+
query_id = query_obj
763+
else:
764+
query_id = query_obj.__class__
765+
766+
if isinstance(registry_obj, ComponentIdentity) or inspect.isclass(registry_obj):
767+
registry_id = registry_obj
768+
else:
769+
registry_id = registry_obj.__class__
751770

752771
suppress_collision_warning = (
753-
get_override(registry_class, 'suppress_collision_warning', False)
754-
or get_override(query_class, 'suppress_collision_warning', False)
772+
get_override(registry_id, 'suppress_collision_warning', False)
773+
or get_override(query_id, 'suppress_collision_warning', False)
755774
)
756-
if query_class != registry_class and not suppress_collision_warning:
775+
if query_id != registry_id and not suppress_collision_warning:
757776
warn(
758777
f'Encountered 2 components with identical names "{component.name}" and '
759-
f'different classes {query_class} and {registry_class}. This will very '
778+
f'different identities {query_id} and {registry_id}. This will very '
760779
f'likely result in an incorrect schema. Try renaming one.'
761780
)
762781
return True

tests/contrib/test_rest_framework_dataclasses.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,49 @@ def custom_name_via_serializer_decoration(request):
9090
generate_schema(None, patterns=urlpatterns),
9191
'tests/contrib/test_rest_framework_dataclasses.yml'
9292
)
93+
94+
95+
@pytest.mark.contrib('rest_framework_dataclasses')
96+
@pytest.mark.skipif(sys.version_info < (3, 7), reason='dataclass required by package')
97+
def test_rest_framework_dataclasses_class_reuse(no_warnings):
98+
from dataclasses import dataclass
99+
100+
from rest_framework_dataclasses.serializers import DataclassSerializer
101+
102+
@dataclass
103+
class Person:
104+
name: str
105+
age: int
106+
107+
@dataclass
108+
class Party:
109+
person: Person
110+
num_persons: int
111+
112+
class PartySerializer(DataclassSerializer[Party]):
113+
class Meta:
114+
dataclass = Party
115+
116+
class PersonSerializer(DataclassSerializer[Person]):
117+
class Meta:
118+
dataclass = Person
119+
120+
@extend_schema(responses=PartySerializer)
121+
@api_view()
122+
def party(request):
123+
pass # pragma: no cover
124+
125+
@extend_schema(responses=PersonSerializer)
126+
@api_view()
127+
def person(request):
128+
pass # pragma: no cover
129+
130+
urlpatterns = [
131+
path('party', person),
132+
path('person', party),
133+
]
134+
135+
schema = generate_schema(None, patterns=urlpatterns)
136+
# just existence is enough to check since its about no_warnings
137+
assert 'Person' in schema['components']['schemas']
138+
assert 'Party' in schema['components']['schemas']

tests/test_warnings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class X2Viewset(mixins.ListModelMixin, viewsets.GenericViewSet):
4949
generate_schema(None, patterns=router.urls)
5050

5151
stderr = capsys.readouterr().err
52-
assert 'Encountered 2 components with identical names "X" and different classes' in stderr
52+
assert 'Encountered 2 components with identical names "X" and different identities' in stderr
5353

5454

5555
def test_owned_serializer_naming_override_with_ref_name_collision(warnings):

0 commit comments

Comments
 (0)