Skip to content

Commit

Permalink
2937 include link to rate limit form in rate limit error messages (#694)
Browse files Browse the repository at this point in the history
* feat: including link to form for rate limit elevenation request in rate limit error message

* feat: add proper error message for rate limit error in lambda handler + rename message field from  to

* fix: test cases

* fix: cleaning up imports in files related to unittest failure in CI

* fix: rename detail -> error in error response body from API for consistency

---------

Co-authored-by: Gerald Iakobinyi-Pich <gerald@gitcoin.co>
  • Loading branch information
nutrina and Gerald Iakobinyi-Pich authored Oct 15, 2024
1 parent 79de019 commit b2c0fb3
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 42 deletions.
41 changes: 39 additions & 2 deletions api/aws_lambdas/tests/test_with_api_request_exception_handling.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import json
from unittest.mock import patch

import pytest
from registry.test.test_passport_submission import mock_passport
from django_ratelimit.exceptions import Ratelimited

from aws_lambdas.exceptions import InvalidRequest
from aws_lambdas.scorer_api_passport.tests.helpers import MockContext
from aws_lambdas.submit_passport.tests.test_submit_passport_lambda import (
make_test_event,
)
from aws_lambdas.utils import with_api_request_exception_handling
from registry.test.test_passport_submission import mock_passport

pytestmark = pytest.mark.django_db

Expand Down Expand Up @@ -151,3 +152,39 @@ def test_with_api_request_exception_handling_bad_event(

assert ret["statusCode"] == 500
assert ret["body"] == '{"error": "An error has occurred"}'


def test_with_api_request_exception_handling_rate_limit_msg(
scorer_api_key,
scorer_community_with_binary_scorer,
passport_holder_addresses,
mocker,
):
with mocker.patch(
"registry.atasks.aget_passport",
return_value=mock_passport,
):
with mocker.patch(
"registry.atasks.validate_credential", side_effect=[[], [], []]
):
with mocker.patch(
"aws_lambdas.utils.check_rate_limit", side_effect=Ratelimited()
):
with mocker.patch(
"aws_lambdas.utils.get_passport_api_rate_limited_msg",
return_value="You have been rate limited msg: https://link/to/rate/limit/form",
):
wrapped_func = with_api_request_exception_handling(func_to_test)

address = passport_holder_addresses[0]["address"].lower()
test_event = make_test_event(
scorer_api_key, address, scorer_community_with_binary_scorer.id
)

ret = wrapped_func(test_event, MockContext())

assert ret["statusCode"] == 429
assert (
ret["body"]
== '{"error": "You have been rate limited msg: https://link/to/rate/limit/form"}'
)
33 changes: 22 additions & 11 deletions api/aws_lambdas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def load_secrets():
from registry.api.utils import (
ApiKey,
check_rate_limit,
get_analysis_api_rate_limited_msg,
get_passport_api_rate_limited_msg,
save_api_key_analytics,
)
from registry.exceptions import ( # noqa: E402
Expand Down Expand Up @@ -144,7 +146,10 @@ def format_response(ret: Any):


def with_request_exception_handling(func):
@wraps(func)
"""
This wrapper is meant to be used for API handler of the **internal** API, like the ceramic-cache related endpoints
"""

def wrapper(event, context, *args):
try:
bind_contextvars(request_id=context.aws_request_id)
Expand All @@ -157,14 +162,17 @@ def wrapper(event, context, *args):
status = e.status_code
message = str(e.detail)
else:
ratelimit_msg = (
get_analysis_api_rate_limited_msg()
if event.get("path", "").startswith("/passport/")
else get_passport_api_rate_limited_msg()
)

error_descriptions: Dict[Any, Tuple[int, str]] = {
Unauthorized: (403, "Unauthorized"),
InvalidToken: (403, "Invalid token"),
InvalidRequest: (400, "Bad request"),
Ratelimited: (
429,
"You have been rate limited. Please try again later.",
),
Ratelimited: (429, ratelimit_msg),
InterfaceError: (500, "DB Error: InterfaceError"),
DataError: (500, "DB Error: DataError"),
OperationalError: (500, "DB Error: OperationalError"),
Expand All @@ -188,7 +196,7 @@ def wrapper(event, context, *args):
"statusDescription": str(e),
"isBase64Encoded": False,
"headers": RESPONSE_HEADERS,
"body": '{"error": "' + message + '"}',
"body": json.dumps({"error": message}),
}

logger.exception(
Expand Down Expand Up @@ -245,16 +253,19 @@ def wrapper(_event, context):
status = e.status_code
message = str(e.detail)
else:
ratelimit_msg = (
get_analysis_api_rate_limited_msg()
if event.get("path", "").startswith("/passport/")
else get_passport_api_rate_limited_msg()
)

error_descriptions: Dict[Any, Tuple[int, str]] = {
Unauthorized: (403, "Unauthorized"),
InvalidToken: (403, "Invalid token"),
InvalidRequest: (400, "Bad request"),
InvalidAddressException: (400, "Invalid address"),
NotFoundApiException: (400, "Bad request"),
Ratelimited: (
429,
"You have been rate limited. Please try again later.",
),
Ratelimited: (429, ratelimit_msg),
InterfaceError: (500, "DB Error: InterfaceError"),
DataError: (500, "DB Error: DataError"),
OperationalError: (500, "DB Error: OperationalError"),
Expand All @@ -278,7 +289,7 @@ def wrapper(_event, context):
"statusDescription": str(e),
"isBase64Encoded": False,
"headers": RESPONSE_HEADERS,
"body": '{"error": "' + message + '"}',
"body": json.dumps({"error": message}),
}
logger.exception(
"Error occurred with Passport API. Response: %s", json.dumps(response)
Expand Down
6 changes: 5 additions & 1 deletion api/ceramic_cache/management/commands/base_cron_cmds.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import datetime
import traceback

from django.core.management.base import BaseCommand


Expand All @@ -14,7 +16,9 @@ def handle(self, *args, **options):
self.handle_cron_job(*args, **options)
except Exception as e:
# Handle any exceptions that occur during the command execution
self.stderr.write(f"CRONJOB ERROR: An error occurred: {e}")
self.stderr.write(
f"CRONJOB ERROR: An error occurred: {e}\n{traceback.format_exc()}"
)
finally:
end_time = datetime.datetime.now()
self.log_time("[TIMING] End job", end_time)
Expand Down
9 changes: 7 additions & 2 deletions api/passport/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from ninja_extra.exceptions import APIException

import api_logging as logging
from registry.api.utils import aapi_key, check_rate_limit, is_valid_address
from registry.api.utils import (
aapi_key,
check_rate_limit,
get_analysis_api_rate_limited_msg,
is_valid_address,
)
from registry.exceptions import InvalidAddressException
from scorer.settings.model_config import MODEL_AGGREGATION_NAMES

Expand Down Expand Up @@ -44,7 +49,7 @@ class DetailedScoreModel(Schema):
def service_unavailable(request, _):
return api.create_response(
request,
{"detail": "You have been rate limited!"},
{"error": get_analysis_api_rate_limited_msg()},
status=429,
)

Expand Down
21 changes: 16 additions & 5 deletions api/passport/test/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,22 @@ def test_rate_limit_is_applied(self):
client = Client()

with patch("registry.api.utils.is_ratelimited", return_value=True):
response = client.get(
"/passport/analysis/0x06e3c221011767FE816D0B8f5B16253E43e4Af7D",
**self.headers,
)
assert response.status_code == 429
with patch(
"registry.api.utils.MBD_API_RATE_LIMITING_FORM",
"https://link/to/rate/limit/form",
):
response = client.get(
"/passport/analysis/0x06e3c221011767FE816D0B8f5B16253E43e4Af7D",
**self.headers,
)
assert response.status_code == 429

data = response.json()

assert (
data["error"]
== "You have been rate limited! Use this form to request a rate limit elevation: https://link/to/rate/limit/form"
)

@patch("passport.api.fetch", side_effect=mock_post_response)
def test_checksummed_address_is_passed_on(self, mock_post):
Expand Down
16 changes: 14 additions & 2 deletions api/registry/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import functools

import api_logging as logging
from account.models import Account, AccountAPIKey
from django.conf import settings
from django.contrib.auth import get_user_model
from django.utils.module_loading import import_string
Expand All @@ -16,13 +14,27 @@
from ninja.compatibility.request import get_headers
from ninja.security import APIKeyHeader
from ninja.security.base import SecuritySchema

import api_logging as logging
from account.models import Account, AccountAPIKey
from registry.api.schema import SubmitPassportPayload
from registry.atasks import asave_api_key_analytics
from registry.exceptions import InvalidScorerIdException, Unauthorized
from registry.tasks import save_api_key_analytics

log = logging.getLogger(__name__)

PASSPORT_API_RATE_LIMITING_FORM = settings.PASSPORT_API_RATE_LIMITING_FORM
MBD_API_RATE_LIMITING_FORM = settings.MBD_API_RATE_LIMITING_FORM


def get_passport_api_rate_limited_msg() -> str:
return f"You have been rate limited! Use this form to request a rate limit elevation: {PASSPORT_API_RATE_LIMITING_FORM}"


def get_analysis_api_rate_limited_msg() -> str:
return f"You have been rate limited! Use this form to request a rate limit elevation: {MBD_API_RATE_LIMITING_FORM}"


def atrack_apikey_usage(track_response=True, payload_param_name=None):
def decorator_track_apikey_usage(func):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
import asyncio
import csv
import json
import os
import time
from io import BytesIO, StringIO, TextIOWrapper
from itertools import islice
from nis import cat

import boto3
from asgiref.sync import sync_to_async
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from eth_utils.address import to_checksum_address
from tqdm import tqdm

from passport.api import fetch_all, handle_get_analysis
from passport.api import handle_get_analysis
from registry.admin import get_s3_client
from registry.models import BatchModelScoringRequest, BatchRequestStatus
from scorer.settings import (
Expand Down Expand Up @@ -42,9 +37,9 @@ async def async_handle(self, *args, **options):
s3_uri = f"s3://{S3_BUCKET}/{S3_OBJECT_KEY}"

# Find the request id from the filename.
filename = S3_OBJECT_KEY.split(
f"{BULK_SCORE_REQUESTS_ADDRESS_LIST_FOLDER}/"
)[-1]
filename = S3_OBJECT_KEY.split(f"{BULK_SCORE_REQUESTS_ADDRESS_LIST_FOLDER}/")[
-1
]
self.stdout.write(f"Search request with filename: `{filename}`")

request = await sync_to_async(BatchModelScoringRequest.objects.get)(
Expand Down Expand Up @@ -72,7 +67,9 @@ async def async_handle(self, *args, **options):
total_rows = sum(1 for row in csv_data)
else:
# The first row is not a header, so include it in the processing
total_rows = 1 + sum(1 for row in csv_data) # Adding the first row already read
total_rows = 1 + sum(
1 for row in csv_data
) # Adding the first row already read

# Reset the reader to the start of the file or just after the header
text.seek(0)
Expand Down
10 changes: 5 additions & 5 deletions api/registry/test/test_command_process_batch_address_upload.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import asyncio
from io import StringIO
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from django.core.management import call_command
from django.test import TransactionTestCase, override_settings
from django.test import TransactionTestCase

from registry.models import BatchModelScoringRequest, BatchRequestStatus

Expand Down Expand Up @@ -41,8 +39,10 @@ def test_process_pending_requests(self):
)

call_command("process_batch_model_address_upload")

updated_request = BatchModelScoringRequest.objects.get(id=good_request.id)

updated_request = BatchModelScoringRequest.objects.get(
id=good_request.id
)
self.assertEqual(
updated_request.status,
BatchRequestStatus.DONE.value,
Expand Down
49 changes: 47 additions & 2 deletions api/registry/test/test_ratelimit.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
from unittest.mock import patch

import api_logging as logging
import pytest
from account.models import AccountAPIKey, RateLimits
from django.test import Client, override_settings

import api_logging as logging
from account.models import AccountAPIKey, RateLimits

pytestmark = pytest.mark.django_db

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -141,6 +142,50 @@ def test_rate_limit_is_applied(scorer_api_key, api_path_that_requires_rate_limit
assert response.status_code == 429


@override_settings(RATELIMIT_ENABLE=True)
def test_rate_limit_msg_contains_form_link(
scorer_api_key, api_path_that_requires_rate_limit
):
"""
Test that the error message conatins the proper form link.
"""
with patch(
"registry.api.utils.PASSPORT_API_RATE_LIMITING_FORM",
"https://link/to/rate/limit/form",
):
method, path, payload = api_path_that_requires_rate_limit
client = Client()

with patch("registry.api.utils.is_ratelimited", return_value=True):
if method == "post":
response = client.post(
path,
json.dumps(payload),
**{
"content_type": "application/json",
"HTTP_AUTHORIZATION": f"Token {scorer_api_key}",
},
)

data = response.json()
assert response.status_code == 429
assert (
data["error"]
== "You have been rate limited! Use this form to request a rate limit elevation: https://link/to/rate/limit/form"
)
else:
response = client.get(
path,
HTTP_AUTHORIZATION=f"Token {scorer_api_key}",
)
data = response.json()
assert response.status_code == 429
assert (
data["error"]
== "You have been rate limited! Use this form to request a rate limit elevation: https://link/to/rate/limit/form"
)


@override_settings(RATELIMIT_ENABLE=True)
def test_no_rate_limit_for_none(unlimited_scorer_api_key):
"""
Expand Down
Loading

0 comments on commit b2c0fb3

Please sign in to comment.