Commit a3c60aa

feat: better rate limiting logic (#317)
* feat: better rate-limiting logic
* fix: if -> while
* feat: better rate-limiting logic
* fix: suppress StopIteration warning
* chore: round log msg to 3 decimals
* chore: `black .`
* chore: suppress spammy logs

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 4ae7575 commit a3c60aa

2 files changed: +62 additions, -38 deletions


dank_mids/brownie_patch/call.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@
         *args: The arguments to be encoded.
     """
 
-    # We do this so ypricemagic's checksum cache monkey patch will work,
+    # We assign this variable so ypricemagic's checksum cache monkey patch will work,
     # This is only relevant to you if your project uses ypricemagic as well.
     to_checksum_address = Address.checksum
 
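The changed line here only rewords the comment above the `to_checksum_address = Address.checksum` alias. As a rough illustration of why that alias matters for a monkey patch, assuming (as the surrounding docstring context suggests) that the assignment sits inside a function body and is therefore re-resolved on every call, the following sketch uses stand-in names only; `Address` and the cached-checksum installer below are illustrative, not dank_mids or ypricemagic code:

# Minimal sketch: resolving Address.checksum at call time means a patch
# applied after import is picked up by later calls.
class Address:
    @staticmethod
    def checksum(value: str) -> str:
        # Stand-in for the uncached checksum routine.
        return value.lower()


def encode_call(target: str) -> str:
    # Re-resolved on each call, so a patched Address.checksum takes effect.
    to_checksum_address = Address.checksum
    return to_checksum_address(target)


def install_cached_checksum() -> None:
    # Hypothetical stand-in for a checksum-cache monkey patch.
    cache: dict = {}

    def cached(value: str) -> str:
        if value not in cache:
            cache[value] = value.lower()
        return cache[value]

    Address.checksum = staticmethod(cached)


encode_call("0xAbC")       # uses the original routine
install_cached_checksum()  # patch applied after import
encode_call("0xAbC")       # now routed through the cached version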

dank_mids/helpers/_session.py

Lines changed: 61 additions & 37 deletions
@@ -5,6 +5,7 @@
 from itertools import chain
 from random import random
 from threading import get_ident
+from time import time
 from typing import Any, Callable, List, Optional, overload
 
 import msgspec
@@ -123,8 +124,18 @@ async def get_session() -> "DankClientSession":
     return await _get_session_for_thread(get_ident())
 
 
+_RETRY_AFTER = 1.0
+
+
 class DankClientSession(ClientSession):
-    async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DECODER, _retry_after: float = 1, **kwargs) -> bytes:  # type: ignore [override]
+    _limited = False
+    _last_rate_limited_at = 0
+    _continue_requests_at = 0
+
+    async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DECODER, **kwargs) -> bytes:  # type: ignore [override]
+        if (now := time()) < self._continue_requests_at:
+            await asyncio.sleep(self._continue_requests_at - now)
+
         # Process input arguments.
         if isinstance(kwargs.get("data"), PartialRequest):
             logger.debug("making request for %s", kwargs["data"])
@@ -142,40 +153,56 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC
                 return response_data
             except ClientResponseError as ce:
                 if ce.status == HTTPStatusExtended.TOO_MANY_REQUESTS:  # type: ignore [attr-defined]
-                    try_after = float(ce.headers.get("Retry-After", _retry_after * 1.5))  # type: ignore [union-attr]
-                    if self not in _limited:
-                        _limited.append(self)
-                        logger.info("You're being rate limited by your node provider")
-                        logger.info(
-                            "Its all good, dank_mids has this handled, but you might get results slower than you'd like"
-                        )
-                    logger.info(f"rate limited: retrying after {try_after}s")
-                    await asyncio.sleep(try_after)
-                    if try_after > 30:
-                        logger.warning("severe rate limiting from your provider")
-                    return await self.post(
-                        endpoint, *args, loads=loads, _retry_after=try_after, **kwargs
+                    await self.handle_too_many_requests(ce)
+                else:
+                    try:
+                        if ce.status not in RETRY_FOR_CODES or tried >= 5:
+                            logger.debug(
+                                "response failed with status %s", HTTPStatusExtended(ce.status)
+                            )
+                            raise ce
+                    except ValueError as ve:
+                        raise (
+                            ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve
+                        ) from ve
+
+                    sleep = random()
+                    await asyncio.sleep(sleep)
+                    logger.debug(
+                        "response failed with status %s, retrying in %ss",
+                        HTTPStatusExtended(ce.status),
+                        round(sleep, 2),
                     )
-
-                try:
-                    if ce.status not in RETRY_FOR_CODES or tried >= 5:
-                        logger.debug(
-                            "response failed with status %s", HTTPStatusExtended(ce.status)
-                        )
-                        raise ce
-                except ValueError as ve:
-                    raise (
-                        ce if str(ve).endswith("is not a valid HTTPStatusExtended") else ve
-                    ) from ve
-
-                sleep = random()
-                await asyncio.sleep(sleep)
-                logger.debug(
-                    "response failed with status %s, retrying in %ss",
-                    HTTPStatusExtended(ce.status),
-                    round(sleep, 2),
-                )
-                tried += 1
+                    tried += 1
+
+    async def handle_too_many_requests(self, error: ClientResponseError) -> None:
+        now = time()
+        self._last_rate_limited_at = now
+        retry_after = float(error.headers.get("Retry-After", _RETRY_AFTER))
+        resume_at = max(
+            self._continue_requests_at + retry_after,
+            self._last_rate_limited_at + retry_after,
+        )
+        retry_after = resume_at - now
+        self._continue_requests_at = resume_at
+
+        self._log_rate_limited(retry_after)
+        await asyncio.sleep(retry_after)
+
+        if retry_after > 30:
+            logger.warning("severe rate limiting from your provider")
+
+    def _log_rate_limited(self, try_after: float) -> None:
+        if not self._limited:
+            self._limited = True
+            logger.info("You're being rate limited by your node provider")
+            logger.info(
+                "Its all good, dank_mids has this handled, but you might get results slower than you'd like"
+            )
+        if try_after < 5:
+            logger.debug("rate limited: retrying after %.3fs", try_after)
        else:
+            logger.info("rate limited: retrying after %.3fs", try_after)
 
 
 @alru_cache(maxsize=None)
@@ -191,6 +218,3 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession:
         raise_for_status=True,
         read_bufsize=2**20,  # 1mb
     )
-
-
-_limited: List[DankClientSession] = []
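The net effect of the new logic: each 429 response pushes a shared resume timestamp forward by the provider's Retry-After header (falling back to the 1-second _RETRY_AFTER default), and every subsequent post() sleeps until that timestamp before sending, instead of the old approach of recursing into post() with a 1.5x-growing _retry_after. A rough standalone sketch of that gating follows; the RateLimitGate class and its method names are illustrative, not the dank_mids API:

import asyncio
from time import time
from typing import Optional


class RateLimitGate:
    """Illustrative stand-in for the timestamps DankClientSession now keeps."""

    def __init__(self, default_retry_after: float = 1.0) -> None:
        self._default = default_retry_after
        self._continue_requests_at = 0.0
        self._last_rate_limited_at = 0.0

    async def wait_before_request(self) -> None:
        # Mirrors the check at the top of post(): send nothing until the
        # shared resume timestamp has passed.
        if (now := time()) < self._continue_requests_at:
            await asyncio.sleep(self._continue_requests_at - now)

    async def on_too_many_requests(self, retry_after_header: Optional[str]) -> None:
        # Mirrors handle_too_many_requests(): push the resume timestamp
        # forward so concurrent callers back off together.
        now = time()
        self._last_rate_limited_at = now
        retry_after = float(retry_after_header) if retry_after_header else self._default
        resume_at = max(
            self._continue_requests_at + retry_after,
            self._last_rate_limited_at + retry_after,
        )
        self._continue_requests_at = resume_at
        await asyncio.sleep(resume_at - now)

Because the timestamps live on the session rather than in each call's stack, a burst of concurrent 429s extends one shared pause instead of each retry sleeping and re-posting on its own.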

0 commit comments
