Skip to content

Commit

Permalink
retry for http wip.
Browse files Browse the repository at this point in the history
Signed-off-by: Elena Kolevska <elena@kolevska.com>
  • Loading branch information
elena-kolevska committed Mar 2, 2024
1 parent a543575 commit ad13b66
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 26 deletions.
59 changes: 48 additions & 11 deletions dapr/clients/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio

import aiohttp

Expand All @@ -24,6 +25,7 @@
CONTENT_TYPE_HEADER,
)
from dapr.clients.health import DaprHealth
from dapr.clients.retry import RetryPolicy

if TYPE_CHECKING:
from dapr.serializers import Serializer
Expand All @@ -41,6 +43,7 @@ def __init__(
message_serializer: 'Serializer',
timeout: Optional[int] = 60,
headers_callback: Optional[Callable[[], Dict[str, str]]] = None,
retry_policy: Optional[RetryPolicy] = None,
):
"""Invokes Dapr over HTTP.
Expand All @@ -49,11 +52,12 @@ def __init__(
timeout (int, optional): Timeout in seconds, defaults to 60.
headers_callback (Callable[[], Dict[str, str]], optional): Generates headers for each request.
"""
DaprHealth.wait_until_ready()
# DaprHealth.wait_until_ready()

self._timeout = aiohttp.ClientTimeout(total=timeout)
self._serializer = message_serializer
self._headers_callback = headers_callback
self.retry_policy = retry_policy or RetryPolicy()

async def send_bytes(
self,
Expand Down Expand Up @@ -82,20 +86,53 @@ async def send_bytes(
sslcontext = self.get_ssl_context()

async with aiohttp.ClientSession(timeout=client_timeout) as session:
r = await session.request(
method=method,
url=url,
data=data,
headers=headers_map,
ssl=sslcontext,
params=query_params,
)

if r.status >= 200 and r.status < 300:
req = {
'method': method,
'url': url,
'data': data,
'headers': headers_map,
'sslcontext': sslcontext,
'params': query_params,
}
r = await self.retry_call(session, req)

if 200 <= r.status < 300:
return await r.read(), r

raise (await self.convert_to_error(r))

async def retry_call(self, session, req):
    """Send one HTTP request, retrying according to ``self.retry_policy``.

    Args:
        session: Object exposing an awaitable ``request(**kwargs)`` method
            (the caller passes an ``aiohttp.ClientSession``).
        req: Mapping with the keys ``method``, ``url``, ``data``, ``headers``,
            ``sslcontext`` and ``params``; ``sslcontext`` is forwarded to the
            session as the ``ssl`` keyword argument.

    Returns:
        The first response whose status is not listed in
        ``retry_policy.retryable_http_status_codes``, or the response of the
        final attempt once the attempts are exhausted. The caller decides
        whether that status is an error.

    Retry budget (``retry_policy.max_attempts``):
        * ``0`` or ``1``: a single request, never retried.
        * ``n > 1``: at most ``n`` requests.
        * ``-1``: retry until a non-retryable status is received.
    """
    import logging

    logger = logging.getLogger(__name__)
    policy = self.retry_policy
    max_attempts = policy.max_attempts

    # The delay is grown incrementally and capped at max_backoff instead of
    # computing ``multiplier ** attempt``: with unlimited retries the power
    # eventually raises OverflowError. For multiplier >= 1 both forms yield
    # the same sequence of delays.
    backoff = min(policy.max_backoff, policy.initial_backoff)

    attempt = 0
    while True:
        attempt += 1
        logger.debug('Trying HTTP call, attempt %d', attempt)
        response = await session.request(
            method=req['method'],
            url=req['url'],
            data=req['data'],
            headers=req['headers'],
            ssl=req['sslcontext'],
            params=req['params'],
        )

        if response.status not in policy.retryable_http_status_codes:
            return response

        # Out of attempts: hand the last response back to the caller untouched.
        if max_attempts != -1 and attempt >= max_attempts:
            return response

        # This response is being discarded, so give its connection back
        # before the next attempt rather than holding it during the sleep.
        response.release()

        logger.debug('Sleeping for %s seconds before retrying HTTP call', backoff)
        await asyncio.sleep(backoff)
        backoff = min(policy.max_backoff, backoff * policy.backoff_multiplier)

async def convert_to_error(self, response: aiohttp.ClientResponse) -> DaprInternalError:
error_info = None
try:
Expand Down
30 changes: 25 additions & 5 deletions dapr/clients/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class RetryPolicy:
initial_backoff (int): The initial backoff duration.
max_backoff (int): The maximum backoff duration.
backoff_multiplier (float): The backoff multiplier.
retryable_status_codes (List[StatusCode]): The list of status codes that are retryable.
retryable_http_status_codes (List[int]): The list of http retryable status codes
retryable_grpc_status_codes (List[StatusCode]): The list of retryable grpc status codes
"""

def __init__(
Expand All @@ -38,16 +39,35 @@ def __init__(
initial_backoff: int = 1,
max_backoff: int = 20,
backoff_multiplier: float = 1.5,
retryable_status_codes: List[StatusCode] = [
retryable_http_status_codes: List[int] = [408, 429, 500, 502, 503, 504],
retryable_grpc_status_codes: List[StatusCode] = [
StatusCode.UNAVAILABLE,
StatusCode.DEADLINE_EXCEEDED,
],
):
if max_attempts < -1:
raise ValueError('max_attempts must be greater than or equal to -1')
self.max_attempts = max_attempts

if initial_backoff < 1:
raise ValueError('initial_backoff must be greater than or equal to 1')
self.initial_backoff = initial_backoff

if max_backoff < 1:
raise ValueError('max_backoff must be greater than or equal to 1')
self.max_backoff = max_backoff

if backoff_multiplier < 1:
raise ValueError('backoff_multiplier must be greater than or equal to 1')
self.backoff_multiplier = backoff_multiplier
self.retryable_status_codes = retryable_status_codes

if len(retryable_http_status_codes) == 0:
raise ValueError('retryable_http_status_codes can\'t be empty')
self.retryable_http_status_codes = retryable_http_status_codes

if len(retryable_grpc_status_codes) == 0:
raise ValueError('retryable_http_status_codes can\'t be empty')
self.retryable_grpc_status_codes = retryable_grpc_status_codes


def run_rpc_with_retry(policy: RetryPolicy, func=Callable, *args, **kwargs):
Expand All @@ -61,7 +81,7 @@ def run_rpc_with_retry(policy: RetryPolicy, func=Callable, *args, **kwargs):
print(f'Trying RPC call, attempt {attempt + 1}')
return func(*args, **kwargs)
except RpcError as err:
if err.code() not in policy.retryable_status_codes:
if err.code() not in policy.retryable_grpc_status_codes:
raise
if policy.max_attempts != -1 and attempt == policy.max_attempts - 1: # type: ignore
raise
Expand Down Expand Up @@ -90,7 +110,7 @@ async def async_run_rpc_with_retry(policy: RetryPolicy, func: Callable, *args, *
result = await call
return result, call
except RpcError as err:
if err.code() not in policy.retryable_status_codes:
if err.code() not in policy.retryable_grpc_status_codes:
raise
if policy.max_attempts != -1 and attempt == policy.max_attempts - 1: # type: ignore
raise
Expand Down
59 changes: 52 additions & 7 deletions tests/clients/test_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,59 @@
from dapr.clients.retry import RetryPolicy, run_rpc_with_retry


class RetriesTest(unittest.TestCase):
def setUp(self):
self.retry_policy = RetryPolicy(
max_attempts=3, retryable_status_codes=[StatusCode.UNAVAILABLE]
class RetryPolicyTests(unittest.TestCase):
    """Construction tests for RetryPolicy: defaults, explicit values, validation."""

    def test_init_success_default(self):
        """A policy created without arguments carries the default settings."""
        default_policy = RetryPolicy()

        expected_defaults = (
            ('max_attempts', 0),
            ('initial_backoff', 1),
            ('max_backoff', 20),
            ('backoff_multiplier', 1.5),
            ('retryable_http_status_codes', [408, 429, 500, 502, 503, 504]),
            (
                'retryable_grpc_status_codes',
                [StatusCode.UNAVAILABLE, StatusCode.DEADLINE_EXCEEDED],
            ),
        )
        for attr_name, expected_value in expected_defaults:
            self.assertEqual(expected_value, getattr(default_policy, attr_name))

    def test_init_success(self):
        """Every constructor argument is stored on the attribute of the same name."""
        custom_policy = RetryPolicy(
            max_attempts=3,
            initial_backoff=2,
            max_backoff=10,
            backoff_multiplier=2,
            retryable_grpc_status_codes=[StatusCode.UNAVAILABLE],
            retryable_http_status_codes=[408, 429],
        )

        expected_values = (
            ('max_attempts', 3),
            ('initial_backoff', 2),
            ('max_backoff', 10),
            ('backoff_multiplier', 2),
            ('retryable_grpc_status_codes', [StatusCode.UNAVAILABLE]),
            ('retryable_http_status_codes', [408, 429]),
        )
        for attr_name, expected_value in expected_values:
            self.assertEqual(expected_value, getattr(custom_policy, attr_name))

    def test_init_with_errors(self):
        """Each out-of-range or empty argument is rejected with ValueError."""
        invalid_arguments = (
            {'max_attempts': -2},
            {'initial_backoff': 0},
            {'max_backoff': 0},
            {'backoff_multiplier': 0},
            {'retryable_http_status_codes': []},
            {'retryable_grpc_status_codes': []},
        )
        for bad_kwargs in invalid_arguments:
            with self.assertRaises(ValueError):
                RetryPolicy(**bad_kwargs)


class RetriesTest(unittest.TestCase):
def test_run_rpc_with_retry_success(self):
mock_func = Mock(return_value='success')

result = run_rpc_with_retry(self.retry_policy, mock_func, 'foo', 'bar', arg1=1, arg2=2)
policy = RetryPolicy(max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
result = run_rpc_with_retry(policy, mock_func, 'foo', 'bar', arg1=1, arg2=2)

self.assertEqual(result, 'success')
mock_func.assert_called_once_with('foo', 'bar', arg1=1, arg2=2)
Expand Down Expand Up @@ -68,7 +111,9 @@ def test_run_rpc_with_retry_fail_with_another_status_code(self):
mock_func = MagicMock(side_effect=mock_error)

with self.assertRaises(RpcError):
run_rpc_with_retry(self.retry_policy, mock_func)
policy = RetryPolicy(max_attempts=3,
retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
run_rpc_with_retry(policy, mock_func)

mock_func.assert_called_once()

Expand Down Expand Up @@ -101,7 +146,7 @@ def test_run_rpc_with_infinite_retries(self, mock_sleep):
# Then we assert that the function was called X times before breaking the loop

# Configure the policy to simulate infinite retries
policy = RetryPolicy(max_attempts=-1, retryable_status_codes=[StatusCode.UNAVAILABLE])
policy = RetryPolicy(max_attempts=-1, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])

mock_error = RpcError()
mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE)
Expand Down
6 changes: 3 additions & 3 deletions tests/clients/test_retries_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class AsyncRetriesTests(unittest.IsolatedAsyncioTestCase):
async def test_run_rpc_with_retry_success(self):
mock_func = AsyncMock(return_value='success')

policy = RetryPolicy(max_attempts=3, retryable_status_codes=[StatusCode.UNAVAILABLE])
policy = RetryPolicy(max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
result, _ = await async_run_rpc_with_retry(policy, mock_func, 'foo', arg1=1, arg2=2)

self.assertEqual(result, 'success')
Expand Down Expand Up @@ -64,7 +64,7 @@ async def test_run_rpc_with_retry_fail_with_another_status_code(self):
mock_func = AsyncMock(side_effect=mock_error)

with self.assertRaises(RpcError):
policy = RetryPolicy(max_attempts=3, retryable_status_codes=[StatusCode.UNAVAILABLE])
policy = RetryPolicy(max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
await async_run_rpc_with_retry(policy, mock_func)

mock_func.assert_awaited_once()
Expand Down Expand Up @@ -99,7 +99,7 @@ async def test_run_rpc_with_infinite_retries(self, mock_sleep):
# Then we assert that the function was called X times before breaking the loop

# Configure the policy to simulate infinite retries
policy = RetryPolicy(max_attempts=-1, retryable_status_codes=[StatusCode.UNAVAILABLE])
policy = RetryPolicy(max_attempts=-1, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])

mock_error = RpcError()
mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE)
Expand Down

0 comments on commit ad13b66

Please sign in to comment.