diff --git a/api/api/serializers/audio_serializers.py b/api/api/serializers/audio_serializers.py index 0da59296267..95dc887c75d 100644 --- a/api/api/serializers/audio_serializers.py +++ b/api/api/serializers/audio_serializers.py @@ -37,10 +37,6 @@ class AudioCollectionRequestSerializer(PaginatedRequestSerializer): default=False, ) - @property - def needs_db(self) -> bool: - return super().needs_db or self.data["peaks"] - class AudioSearchRequestSerializer( AudioSearchRequestSourceSerializer, @@ -75,10 +71,6 @@ class AudioSearchRequestSerializer( default=False, ) - @property - def needs_db(self) -> bool: - return super().needs_db or self.data["peaks"] - def validate_internal__index(self, value): if not (index := super().validate_internal__index(value)): return None @@ -147,8 +139,6 @@ class Meta: used to generate Swagger documentation. """ - needs_db = True # for the 'thumbnail' field - audio_set = AudioSetSerializer( allow_null=True, help_text="Reference to set of which this track is a part.", diff --git a/api/api/serializers/image_serializers.py b/api/api/serializers/image_serializers.py index faf08d7b5bf..9238ce0b625 100644 --- a/api/api/serializers/image_serializers.py +++ b/api/api/serializers/image_serializers.py @@ -107,8 +107,6 @@ class Meta: used to generate Swagger documentation. """ - needs_db = True # for the 'height' and 'width' fields - ########################## # Additional serializers # diff --git a/api/api/serializers/media_serializers.py b/api/api/serializers/media_serializers.py index d8a52d4c1f8..961ce5b5027 100644 --- a/api/api/serializers/media_serializers.py +++ b/api/api/serializers/media_serializers.py @@ -81,10 +81,6 @@ def validate_page_size(self, value): return value - @property - def needs_db(self) -> bool: - return False - @extend_schema_serializer( # Hide unstable and internal fields from documentation. @@ -497,9 +493,6 @@ class Meta: used to generate Swagger documentation. """ - needs_db = False - """whether the serializer needs fields from the DB to process results""" - id = serializers.CharField( help_text="Our unique identifier for an open-licensed work.", source="identifier", diff --git a/api/api/views/media_views.py b/api/api/views/media_views.py index ace71bb86f6..0ae6bf66cd6 100644 --- a/api/api/views/media_views.py +++ b/api/api/views/media_views.py @@ -237,9 +237,7 @@ def get_media_results( serializer_context = search_context | self.get_serializer_context() - serializer_class = self.get_serializer() - if params.needs_db or serializer_class.needs_db: - results = self.get_db_results(results) + results = self.get_db_results(results) serializer = self.get_serializer(results, many=True, context=serializer_context) return self.get_paginated_response(serializer.data) @@ -280,9 +278,7 @@ def related(self, request, identifier=None, *_, **__): serializer_context = self.get_serializer_context() - serializer_class = self.get_serializer() - if serializer_class.needs_db: - results = self.get_db_results(results) + results = self.get_db_results(results) serializer = self.get_serializer(results, many=True, context=serializer_context) return self.get_paginated_response(serializer.data) diff --git a/api/test/test_auth.py b/api/test/test_auth.py index 6387b1818ed..117193bbaeb 100644 --- a/api/test/test_auth.py +++ b/api/test/test_auth.py @@ -9,7 +9,7 @@ from oauth2_provider.models import AccessToken from api.models import OAuth2Verification, ThrottledApplication - +from unittest.mock import patch cache_availability_params = pytest.mark.parametrize( "is_cache_reachable, cache_name", @@ -206,16 +206,21 @@ def test_unauthed_response_headers(client): ("asc", "2022-01-01"), ], ) -def test_sorting_authed( - client, monkeypatch, test_auth_token_exchange, sort_dir, exp_indexed_on -): - # Prevent DB lookup for ES results because DB is empty. - monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False) - +def test_sorting_authed(client, test_auth_token_exchange, sort_dir, exp_indexed_on): time.sleep(1) token = test_auth_token_exchange["access_token"] - query_params = {"unstable__sort_by": "indexed_on", "unstable__sort_dir": sort_dir} - res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}") + query_params = { + "unstable__sort_by": "indexed_on", + "unstable__sort_dir": sort_dir, + } + with patch( + "api.views.image_views.ImageViewSet.get_db_results" + ) as mock_get_db_result: + mock_get_db_result.side_effect = lambda value: value + + res = client.get( + "/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}" + ) assert res.status_code == 200 res_data = res.json() @@ -232,11 +237,8 @@ def test_sorting_authed( ], ) def test_authority_authed( - client, monkeypatch, test_auth_token_exchange, authority_boost, exp_source + client, test_auth_token_exchange, authority_boost, exp_source ): - # Prevent DB lookup for ES results because DB is empty. - monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False) - time.sleep(1) token = test_auth_token_exchange["access_token"] query_params = { @@ -244,7 +246,14 @@ def test_authority_authed( "unstable__authority": "true", "unstable__authority_boost": authority_boost, } - res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}") + with patch( + "api.views.image_views.ImageViewSet.get_db_results" + ) as mock_get_db_result: + mock_get_db_result.side_effect = lambda value: value + + res = client.get( + "/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}" + ) assert res.status_code == 200 res_data = res.json() diff --git a/api/test/test_dead_link_filter.py b/api/test/test_dead_link_filter.py index bb37d8ed95d..f767023b3c0 100644 --- a/api/test/test_dead_link_filter.py +++ b/api/test/test_dead_link_filter.py @@ -10,20 +10,6 @@ from api.controllers.elasticsearch.helpers import DEAD_LINK_RATIO -@pytest.fixture(autouse=True) -def turn_off_db_read(monkeypatch): - """ - Prevent DB lookup for ES results because DB is empty. - - Since ImageSerializer has set ``needs_db`` to ``True``, all results from ES will be - mapped to DB models. Since the test DB is empty, results array will be empty. By - patching ``needs_db`` to ``False``, we can test the dead link filtering process - without needing to populate the test DB. - """ - - monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False) - - @pytest.fixture def unique_query_hash(redis, monkeypatch): def get_unique_hash(*args, **kwargs): @@ -81,20 +67,24 @@ def test_dead_link_filtering(mocked_map, client): query_params = {"q": "*", "page_size": 20} # Make a request that does not filter dead links... - res_with_dead_links = client.get( - path, - query_params | {"filter_dead": False}, - ) - # ...and ensure that our patched function was not called - mocked_map.assert_not_called() + with patch( + "api.views.image_views.ImageViewSet.get_db_results" + ) as mock_get_db_result: + mock_get_db_result.side_effect = lambda value: value + res_with_dead_links = client.get( + path, + query_params | {"filter_dead": False}, + ) + # ...and ensure that our patched function was not called + mocked_map.assert_not_called() - # Make a request that filters dead links... - res_without_dead_links = client.get( - path, - query_params | {"filter_dead": True}, - ) - # ...and ensure that our patched function was called - mocked_map.assert_called() + # Make a request that filters dead links... + res_without_dead_links = client.get( + path, + query_params | {"filter_dead": True}, + ) + # ...and ensure that our patched function was called + mocked_map.assert_called() assert res_with_dead_links.status_code == 200 assert res_without_dead_links.status_code == 200 @@ -131,11 +121,15 @@ def test_dead_link_filtering_all_dead_links( path = "/v1/images/" query_params = {"q": "*", "page_size": page_size} - with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO): - response = client.get( - path, - query_params | {"filter_dead": filter_dead}, - ) + with patch( + "api.views.image_views.ImageViewSet.get_db_results" + ) as mock_get_db_result: + mock_get_db_result.side_effect = lambda value: value + with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO): + response = client.get( + path, + query_params | {"filter_dead": filter_dead}, + ) assert response.status_code == 200