diff --git a/dapr/clients/http/client.py b/dapr/clients/http/client.py
index 0d591156..627fddae 100644
--- a/dapr/clients/http/client.py
+++ b/dapr/clients/http/client.py
@@ -12,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+import asyncio
 
 import aiohttp
 
@@ -24,6 +25,7 @@
     CONTENT_TYPE_HEADER,
 )
 from dapr.clients.health import DaprHealth
+from dapr.clients.retry import RetryPolicy
 
 if TYPE_CHECKING:
     from dapr.serializers import Serializer
@@ -41,6 +43,7 @@ def __init__(
         message_serializer: 'Serializer',
         timeout: Optional[int] = 60,
         headers_callback: Optional[Callable[[], Dict[str, str]]] = None,
+        retry_policy: Optional[RetryPolicy] = None,
     ):
         """Invokes Dapr over HTTP.
 
@@ -49,11 +52,12 @@ def __init__(
             timeout (int, optional): Timeout in seconds, defaults to 60.
             headers_callback (lambda: Dict[str, str]], optional): Generates header for each request.
         """
         DaprHealth.wait_until_ready()
 
         self._timeout = aiohttp.ClientTimeout(total=timeout)
         self._serializer = message_serializer
         self._headers_callback = headers_callback
+        self.retry_policy = retry_policy or RetryPolicy()
 
     async def send_bytes(
         self,
@@ -82,20 +86,62 @@ async def send_bytes(
         sslcontext = self.get_ssl_context()
 
         async with aiohttp.ClientSession(timeout=client_timeout) as session:
-            r = await session.request(
-                method=method,
-                url=url,
-                data=data,
-                headers=headers_map,
-                ssl=sslcontext,
-                params=query_params,
-            )
-
-            if r.status >= 200 and r.status < 300:
+            req = {
+                'method': method,
+                'url': url,
+                'data': data,
+                'headers': headers_map,
+                'sslcontext': sslcontext,
+                'params': query_params,
+            }
+            r = await self.retry_call(session, req)
+
+            if 200 <= r.status < 300:
                 return await r.read(), r
 
         raise (await self.convert_to_error(r))
 
+    async def retry_call(self, session, req):
+        """Send the request described by ``req``, retrying per ``self.retry_policy``.
+
+        A response is retried only while its status is listed in
+        ``retry_policy.retryable_http_status_codes``. When attempts run out, the last
+        response is returned unchanged so the caller can convert it into an error.
+        """
+        # If max_retries is 0, we don't retry
+        if self.retry_policy.max_attempts == 0:
+            return await session.request(
+                method=req["method"],
+                url=req["url"],
+                data=req["data"],
+                headers=req["headers"],
+                ssl=req["sslcontext"],
+                params=req["params"],
+            )
+
+        attempt = 0
+        while self.retry_policy.max_attempts == -1 or attempt < self.retry_policy.max_attempts:  # type: ignore
+            print(f'Trying RPC call, attempt {attempt + 1}')
+            r = await session.request(method=req["method"], url=req["url"], data=req["data"],
+                                      headers=req["headers"], ssl=req["sslcontext"], params=req["params"], )
+
+            if r.status not in self.retry_policy.retryable_http_status_codes:
+                return r
+
+            if self.retry_policy.max_attempts != -1 and attempt == self.retry_policy.max_attempts - 1:  # type: ignore
+                return r
+
+            # This response is being discarded: hand its connection back before retrying
+            r.release()
+
+            sleep_time = min(self.retry_policy.max_backoff,
+                             self.retry_policy.initial_backoff * (self.retry_policy.backoff_multiplier ** attempt), )
+
+            print(f'Sleeping for {sleep_time} seconds before retrying RPC call')
+            await asyncio.sleep(sleep_time)
+            attempt += 1
+        raise Exception(f'Request failed after {attempt} retries')
+
     async def convert_to_error(self, response: aiohttp.ClientResponse) -> DaprInternalError:
         error_info = None
         try:
diff --git a/dapr/clients/retry.py b/dapr/clients/retry.py
index 7e66f9fd..022eaa76 100644
--- a/dapr/clients/retry.py
+++ b/dapr/clients/retry.py
@@ -29,7 +29,8 @@ class RetryPolicy:
         initial_backoff (int): The initial backoff duration.
         max_backoff (int): The maximum backoff duration.
         backoff_multiplier (float): The backoff multiplier.
-        retryable_status_codes (List[StatusCode]): The list of status codes that are retryable.
+        retryable_http_status_codes (List[int]): The list of http retryable status codes
+        retryable_grpc_status_codes (List[StatusCode]): The list of retryable grpc status codes
     """
 
     def __init__(
@@ -38,16 +39,35 @@ def __init__(
         initial_backoff: int = 1,
         max_backoff: int = 20,
         backoff_multiplier: float = 1.5,
-        retryable_status_codes: List[StatusCode] = [
+        retryable_http_status_codes: List[int] = [408, 429, 500, 502, 503, 504],
+        retryable_grpc_status_codes: List[StatusCode] = [
             StatusCode.UNAVAILABLE,
             StatusCode.DEADLINE_EXCEEDED,
         ],
     ):
+        if max_attempts < -1:
+            raise ValueError('max_attempts must be greater than or equal to -1')
         self.max_attempts = max_attempts
+
+        if initial_backoff < 1:
+            raise ValueError('initial_backoff must be greater than or equal to 1')
         self.initial_backoff = initial_backoff
+
+        if max_backoff < 1:
+            raise ValueError('max_backoff must be greater than or equal to 1')
         self.max_backoff = max_backoff
+
+        if backoff_multiplier < 1:
+            raise ValueError('backoff_multiplier must be greater than or equal to 1')
         self.backoff_multiplier = backoff_multiplier
-        self.retryable_status_codes = retryable_status_codes
+
+        if len(retryable_http_status_codes) == 0:
+            raise ValueError('retryable_http_status_codes can\'t be empty')
+        self.retryable_http_status_codes = retryable_http_status_codes
+
+        if len(retryable_grpc_status_codes) == 0:
+            raise ValueError('retryable_grpc_status_codes can\'t be empty')
+        self.retryable_grpc_status_codes = retryable_grpc_status_codes
 
 
 def run_rpc_with_retry(policy: RetryPolicy, func=Callable, *args, **kwargs):
@@ -61,7 +81,7 @@ def run_rpc_with_retry(policy: RetryPolicy, func=Callable, *args, **kwargs):
             print(f'Trying RPC call, attempt {attempt + 1}')
             return func(*args, **kwargs)
         except RpcError as err:
-            if err.code() not in policy.retryable_status_codes:
+            if err.code() not in policy.retryable_grpc_status_codes:
                 raise
             if policy.max_attempts != -1 and attempt == policy.max_attempts - 1:  # type: ignore
                 raise
@@ -90,7 +110,7 @@ async def async_run_rpc_with_retry(policy: RetryPolicy, func: Callable, *args, *
             result = await call
             return result, call
         except RpcError as err:
-            if err.code() not in policy.retryable_status_codes:
+            if err.code() not in policy.retryable_grpc_status_codes:
                 raise
             if policy.max_attempts != -1 and attempt == policy.max_attempts - 1:  # type: ignore
                 raise
diff --git a/tests/clients/test_retries.py b/tests/clients/test_retries.py
index 35a50d96..dd017529 100644
--- a/tests/clients/test_retries.py
+++ b/tests/clients/test_retries.py
@@ -21,16 +21,59 @@
 from dapr.clients.retry import RetryPolicy, run_rpc_with_retry
 
 
-class RetriesTest(unittest.TestCase):
-    def setUp(self):
-        self.retry_policy = RetryPolicy(
-            max_attempts=3, retryable_status_codes=[StatusCode.UNAVAILABLE]
+class RetryPolicyTests(unittest.TestCase):
+    def test_init_success_default(self):
+        policy = RetryPolicy()
+
+        self.assertEqual(0, policy.max_attempts)
+        self.assertEqual(1, policy.initial_backoff)
+        self.assertEqual(20, policy.max_backoff)
+        self.assertEqual(1.5, policy.backoff_multiplier)
+        self.assertEqual([408, 429, 500, 502, 503, 504], policy.retryable_http_status_codes)
+        self.assertEqual([StatusCode.UNAVAILABLE, StatusCode.DEADLINE_EXCEEDED], policy.retryable_grpc_status_codes)
+
+    def test_init_success(self):
+        policy = RetryPolicy(
+            max_attempts=3,
+            initial_backoff=2,
+            max_backoff=10,
+            backoff_multiplier=2,
+            retryable_grpc_status_codes=[StatusCode.UNAVAILABLE],
+            retryable_http_status_codes=[408, 429]
         )
+        self.assertEqual(3, policy.max_attempts)
+        self.assertEqual(2, policy.initial_backoff)
+        self.assertEqual(10, policy.max_backoff)
+        self.assertEqual(2, policy.backoff_multiplier)
+        self.assertEqual([StatusCode.UNAVAILABLE], policy.retryable_grpc_status_codes)
+        self.assertEqual([408, 429], policy.retryable_http_status_codes)
+
+    def test_init_with_errors(self):
+        with self.assertRaises(ValueError):
+            RetryPolicy(max_attempts=-2)
+
+        with self.assertRaises(ValueError):
+            RetryPolicy(initial_backoff=0)
+
+        with self.assertRaises(ValueError):
+            RetryPolicy(max_backoff=0)
+
+        with self.assertRaises(ValueError):
+            RetryPolicy(backoff_multiplier=0)
 
+        with self.assertRaises(ValueError):
+            RetryPolicy(retryable_http_status_codes=[])
+
+        with self.assertRaises(ValueError):
+            RetryPolicy(retryable_grpc_status_codes=[])
+
+
+class RetriesTest(unittest.TestCase):
     def test_run_rpc_with_retry_success(self):
         mock_func = Mock(return_value='success')
 
-        result = run_rpc_with_retry(self.retry_policy, mock_func, 'foo', 'bar', arg1=1, arg2=2)
+        policy = RetryPolicy(max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
+        result = run_rpc_with_retry(policy, mock_func, 'foo', 'bar', arg1=1, arg2=2)
 
         self.assertEqual(result, 'success')
         mock_func.assert_called_once_with('foo', 'bar', arg1=1, arg2=2)
@@ -68,7 +111,9 @@ def test_run_rpc_with_retry_fail_with_another_status_code(self):
         mock_func = MagicMock(side_effect=mock_error)
 
         with self.assertRaises(RpcError):
-            run_rpc_with_retry(self.retry_policy, mock_func)
+            policy = RetryPolicy(max_attempts=3,
+                                 retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
+            run_rpc_with_retry(policy, mock_func)
 
         mock_func.assert_called_once()
 
@@ -101,7 +146,7 @@ def test_run_rpc_with_infinite_retries(self, mock_sleep):
         # Then we assert that the function was called X times before breaking the loop
 
         # Configure the policy to simulate infinite retries
-        policy = RetryPolicy(max_attempts=-1, retryable_status_codes=[StatusCode.UNAVAILABLE])
+        policy = RetryPolicy(max_attempts=-1, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
 
         mock_error = RpcError()
         mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE)
diff --git a/tests/clients/test_retries_async.py b/tests/clients/test_retries_async.py
index 36a64410..44dd8cc2 100644
--- a/tests/clients/test_retries_async.py
+++ b/tests/clients/test_retries_async.py
@@ -25,7 +25,7 @@ class AsyncRetriesTests(unittest.IsolatedAsyncioTestCase):
     async def test_run_rpc_with_retry_success(self):
         mock_func = AsyncMock(return_value='success')
 
-        policy = RetryPolicy(max_attempts=3, retryable_status_codes=[StatusCode.UNAVAILABLE])
+        policy = RetryPolicy(max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
         result, _ = await async_run_rpc_with_retry(policy, mock_func, 'foo', arg1=1, arg2=2)
 
         self.assertEqual(result, 'success')
@@ -64,7 +64,7 @@ async def test_run_rpc_with_retry_fail_with_another_status_code(self):
         mock_func = AsyncMock(side_effect=mock_error)
 
         with self.assertRaises(RpcError):
-            policy = RetryPolicy(max_attempts=3, retryable_status_codes=[StatusCode.UNAVAILABLE])
+            policy = RetryPolicy(max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
             await async_run_rpc_with_retry(policy, mock_func)
 
         mock_func.assert_awaited_once()
@@ -99,7 +99,7 @@ async def test_run_rpc_with_infinite_retries(self, mock_sleep):
         # Then we assert that the function was called X times before breaking the loop
 
         # Configure the policy to simulate infinite retries
-        policy = RetryPolicy(max_attempts=-1, retryable_status_codes=[StatusCode.UNAVAILABLE])
+        policy = RetryPolicy(max_attempts=-1, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
 
         mock_error = RpcError()
         mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE)