5
5
from itertools import chain
6
6
from random import random
7
7
from threading import get_ident
8
+ from time import time
8
9
from typing import Any , Callable , List , Optional , overload
9
10
10
11
import msgspec
@@ -123,8 +124,18 @@ async def get_session() -> "DankClientSession":
123
124
return await _get_session_for_thread (get_ident ())
124
125
125
126
127
+ _RETRY_AFTER = 1.0
128
+
129
+
126
130
class DankClientSession(ClientSession):
    # Set to True the first time this session is rate limited; gates the
    # one-time "you're being rate limited" announcement in _log_rate_limited.
    _limited = False
    # Epoch timestamp (time.time() seconds) of the most recent 429 response.
    _last_rate_limited_at = 0
    # Epoch timestamp before which post() must not send requests; pushed
    # forward by handle_too_many_requests on each 429.
    _continue_requests_at = 0
135
+ async def post (self , endpoint : str , * args , loads : JSONDecoder = DEFAULT_JSON_DECODER , ** kwargs ) -> bytes : # type: ignore [override]
136
+ if (now := time ()) < self ._continue_requests_at :
137
+ await asyncio .sleep (self ._continue_requests_at - now )
138
+
128
139
# Process input arguments.
129
140
if isinstance (kwargs .get ("data" ), PartialRequest ):
130
141
logger .debug ("making request for %s" , kwargs ["data" ])
@@ -142,40 +153,56 @@ async def post(self, endpoint: str, *args, loads: JSONDecoder = DEFAULT_JSON_DEC
142
153
return response_data
143
154
except ClientResponseError as ce :
144
155
if ce .status == HTTPStatusExtended .TOO_MANY_REQUESTS : # type: ignore [attr-defined]
145
- try_after = float (ce .headers .get ("Retry-After" , _retry_after * 1.5 )) # type: ignore [union-attr]
146
- if self not in _limited :
147
- _limited .append (self )
148
- logger .info ("You're being rate limited by your node provider" )
149
- logger .info (
150
- "Its all good, dank_mids has this handled, but you might get results slower than you'd like"
151
- )
152
- logger .info (f"rate limited: retrying after { try_after } s" )
153
- await asyncio .sleep (try_after )
154
- if try_after > 30 :
155
- logger .warning ("severe rate limiting from your provider" )
156
- return await self .post (
157
- endpoint , * args , loads = loads , _retry_after = try_after , ** kwargs
156
+ await self .handle_too_many_requests (ce )
157
+ else :
158
+ try :
159
+ if ce .status not in RETRY_FOR_CODES or tried >= 5 :
160
+ logger .debug (
161
+ "response failed with status %s" , HTTPStatusExtended (ce .status )
162
+ )
163
+ raise ce
164
+ except ValueError as ve :
165
+ raise (
166
+ ce if str (ve ).endswith ("is not a valid HTTPStatusExtended" ) else ve
167
+ ) from ve
168
+
169
+ sleep = random ()
170
+ await asyncio .sleep (sleep )
171
+ logger .debug (
172
+ "response failed with status %s, retrying in %ss" ,
173
+ HTTPStatusExtended (ce .status ),
174
+ round (sleep , 2 ),
158
175
)
159
-
160
- try :
161
- if ce .status not in RETRY_FOR_CODES or tried >= 5 :
162
- logger .debug (
163
- "response failed with status %s" , HTTPStatusExtended (ce .status )
164
- )
165
- raise ce
166
- except ValueError as ve :
167
- raise (
168
- ce if str (ve ).endswith ("is not a valid HTTPStatusExtended" ) else ve
169
- ) from ve
170
-
171
- sleep = random ()
172
- await asyncio .sleep (sleep )
173
- logger .debug (
174
- "response failed with status %s, retrying in %ss" ,
175
- HTTPStatusExtended (ce .status ),
176
- round (sleep , 2 ),
177
- )
178
- tried += 1
176
+ tried += 1
177
+
178
+ async def handle_too_many_requests (self , error : ClientResponseError ) -> None :
179
+ now = time ()
180
+ self ._last_rate_limited_at = now
181
+ retry_after = float (error .headers .get ("Retry-After" , _RETRY_AFTER ))
182
+ resume_at = max (
183
+ self ._continue_requests_at + retry_after ,
184
+ self ._last_rate_limited_at + retry_after ,
185
+ )
186
+ retry_after = resume_at - now
187
+ self ._continue_requests_at = resume_at
188
+
189
+ self ._log_rate_limited (retry_after )
190
+ await asyncio .sleep (retry_after )
191
+
192
+ if retry_after > 30 :
193
+ logger .warning ("severe rate limiting from your provider" )
194
+
195
+ def _log_rate_limited (self , try_after : float ) -> None :
196
+ if not self ._limited :
197
+ self ._limited = True
198
+ logger .info ("You're being rate limited by your node provider" )
199
+ logger .info (
200
+ "Its all good, dank_mids has this handled, but you might get results slower than you'd like"
201
+ )
202
+ if try_after < 5 :
203
+ logger .debug ("rate limited: retrying after %.3fs" , try_after )
204
+ else :
205
+ logger .info ("rate limited: retrying after %.3fs" , try_after )
179
206
180
207
181
208
@alru_cache (maxsize = None )
@@ -191,6 +218,3 @@ async def _get_session_for_thread(thread_ident: int) -> DankClientSession:
191
218
raise_for_status = True ,
192
219
read_bufsize = 2 ** 20 , # 1mb
193
220
)
194
-
195
-
196
- _limited : List [DankClientSession ] = []
0 commit comments