Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove needs db and add patch #3687

Merged
merged 17 commits into from
Jan 24, 2024
10 changes: 0 additions & 10 deletions api/api/serializers/audio_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ class AudioCollectionRequestSerializer(PaginatedRequestSerializer):
default=False,
)

@property
def needs_db(self) -> bool:
return super().needs_db or self.data["peaks"]


class AudioSearchRequestSerializer(
AudioSearchRequestSourceSerializer,
Expand Down Expand Up @@ -75,10 +71,6 @@ class AudioSearchRequestSerializer(
default=False,
)

@property
def needs_db(self) -> bool:
return super().needs_db or self.data["peaks"]

def validate_internal__index(self, value):
if not (index := super().validate_internal__index(value)):
return None
Expand Down Expand Up @@ -147,8 +139,6 @@ class Meta:
used to generate Swagger documentation.
"""

needs_db = True # for the 'thumbnail' field

audio_set = AudioSetSerializer(
allow_null=True,
help_text="Reference to set of which this track is a part.",
Expand Down
2 changes: 0 additions & 2 deletions api/api/serializers/image_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ class Meta:
used to generate Swagger documentation.
"""

needs_db = True # for the 'height' and 'width' fields


##########################
# Additional serializers #
Expand Down
7 changes: 0 additions & 7 deletions api/api/serializers/media_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ def validate_page_size(self, value):

return value

@property
def needs_db(self) -> bool:
return False


@extend_schema_serializer(
# Hide unstable and internal fields from documentation.
Expand Down Expand Up @@ -497,9 +493,6 @@ class Meta:
used to generate Swagger documentation.
"""

needs_db = False
"""whether the serializer needs fields from the DB to process results"""

id = serializers.CharField(
help_text="Our unique identifier for an open-licensed work.",
source="identifier",
Expand Down
8 changes: 2 additions & 6 deletions api/api/views/media_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,7 @@ def get_media_results(

serializer_context = search_context | self.get_serializer_context()

serializer_class = self.get_serializer()
if params.needs_db or serializer_class.needs_db:
results = self.get_db_results(results)
results = self.get_db_results(results)

serializer = self.get_serializer(results, many=True, context=serializer_context)
return self.get_paginated_response(serializer.data)
Expand Down Expand Up @@ -280,9 +278,7 @@ def related(self, request, identifier=None, *_, **__):

serializer_context = self.get_serializer_context()

serializer_class = self.get_serializer()
if serializer_class.needs_db:
results = self.get_db_results(results)
results = self.get_db_results(results)

serializer = self.get_serializer(results, many=True, context=serializer_context)
return self.get_paginated_response(serializer.data)
Expand Down
37 changes: 23 additions & 14 deletions api/test/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from oauth2_provider.models import AccessToken

from api.models import OAuth2Verification, ThrottledApplication

from unittest.mock import patch

cache_availability_params = pytest.mark.parametrize(
"is_cache_reachable, cache_name",
Expand Down Expand Up @@ -206,16 +206,21 @@ def test_unauthed_response_headers(client):
("asc", "2022-01-01"),
],
)
def test_sorting_authed(
client, monkeypatch, test_auth_token_exchange, sort_dir, exp_indexed_on
):
# Prevent DB lookup for ES results because DB is empty.
monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False)

def test_sorting_authed(client, test_auth_token_exchange, sort_dir, exp_indexed_on):
time.sleep(1)
token = test_auth_token_exchange["access_token"]
query_params = {"unstable__sort_by": "indexed_on", "unstable__sort_dir": sort_dir}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
query_params = {
"unstable__sort_by": "indexed_on",
"unstable__sort_dir": sort_dir,
}
with patch(
"api.views.image_views.ImageViewSet.get_db_results"
) as mock_get_db_result:
mock_get_db_result.side_effect = lambda value: value

res = client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
dhruvkb marked this conversation as resolved.
Show resolved Hide resolved
assert res.status_code == 200

res_data = res.json()
Expand All @@ -232,19 +237,23 @@ def test_sorting_authed(
],
)
def test_authority_authed(
client, monkeypatch, test_auth_token_exchange, authority_boost, exp_source
client, test_auth_token_exchange, authority_boost, exp_source
):
# Prevent DB lookup for ES results because DB is empty.
monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False)

time.sleep(1)
token = test_auth_token_exchange["access_token"]
query_params = {
"q": "cat",
"unstable__authority": "true",
"unstable__authority_boost": authority_boost,
}
res = client.get("/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}")
with patch(
"api.views.image_views.ImageViewSet.get_db_results"
) as mock_get_db_result:
mock_get_db_result.side_effect = lambda value: value

res = client.get(
"/v1/images/", query_params, HTTP_AUTHORIZATION=f"Bearer {token}"
)
assert res.status_code == 200

res_data = res.json()
Expand Down
58 changes: 26 additions & 32 deletions api/test/test_dead_link_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,6 @@
from api.controllers.elasticsearch.helpers import DEAD_LINK_RATIO


@pytest.fixture(autouse=True)
def turn_off_db_read(monkeypatch):
"""
Prevent DB lookup for ES results because DB is empty.

Since ImageSerializer has set ``needs_db`` to ``True``, all results from ES will be
mapped to DB models. Since the test DB is empty, results array will be empty. By
patching ``needs_db`` to ``False``, we can test the dead link filtering process
without needing to populate the test DB.
"""

monkeypatch.setattr("api.views.image_views.ImageSerializer.needs_db", False)


@pytest.fixture
def unique_query_hash(redis, monkeypatch):
def get_unique_hash(*args, **kwargs):
Expand Down Expand Up @@ -81,20 +67,24 @@ def test_dead_link_filtering(mocked_map, client):
query_params = {"q": "*", "page_size": 20}

# Make a request that does not filter dead links...
res_with_dead_links = client.get(
path,
query_params | {"filter_dead": False},
)
# ...and ensure that our patched function was not called
mocked_map.assert_not_called()
with patch(
"api.views.image_views.ImageViewSet.get_db_results"
) as mock_get_db_result:
mock_get_db_result.side_effect = lambda value: value
res_with_dead_links = client.get(
path,
query_params | {"filter_dead": False},
)
# ...and ensure that our patched function was not called
mocked_map.assert_not_called()

# Make a request that filters dead links...
res_without_dead_links = client.get(
path,
query_params | {"filter_dead": True},
)
# ...and ensure that our patched function was called
mocked_map.assert_called()
# Make a request that filters dead links...
res_without_dead_links = client.get(
path,
query_params | {"filter_dead": True},
)
# ...and ensure that our patched function was called
mocked_map.assert_called()

assert res_with_dead_links.status_code == 200
assert res_without_dead_links.status_code == 200
Expand Down Expand Up @@ -131,11 +121,15 @@ def test_dead_link_filtering_all_dead_links(
path = "/v1/images/"
query_params = {"q": "*", "page_size": page_size}

with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO):
response = client.get(
path,
query_params | {"filter_dead": filter_dead},
)
with patch(
"api.views.image_views.ImageViewSet.get_db_results"
) as mock_get_db_result:
mock_get_db_result.side_effect = lambda value: value
with patch_link_validation_dead_for_count(page_size / DEAD_LINK_RATIO):
response = client.get(
path,
query_params | {"filter_dead": filter_dead},
)

assert response.status_code == 200

Expand Down