diff --git a/dapr/actor/client/proxy.py b/dapr/actor/client/proxy.py index fd62d271..a7648bf9 100644 --- a/dapr/actor/client/proxy.py +++ b/dapr/actor/client/proxy.py @@ -20,6 +20,7 @@ from dapr.actor.id import ActorId from dapr.actor.runtime._type_utils import get_dispatchable_attrs_from_interface from dapr.clients import DaprActorClientBase, DaprActorHttpClient +from dapr.clients.retry import RetryPolicy from dapr.serializers import Serializer, DefaultJSONSerializer from dapr.conf import settings @@ -50,9 +51,12 @@ def __init__( self, message_serializer=DefaultJSONSerializer(), http_timeout_seconds: int = settings.DAPR_HTTP_TIMEOUT_SECONDS, + retry_policy: Optional[RetryPolicy] = None, ): # TODO: support serializer for state store later - self._dapr_client = DaprActorHttpClient(message_serializer, timeout=http_timeout_seconds) + self._dapr_client = DaprActorHttpClient( + message_serializer, timeout=http_timeout_seconds, retry_policy=retry_policy + ) self._message_serializer = message_serializer def create( diff --git a/dapr/aio/clients/grpc/client.py b/dapr/aio/clients/grpc/client.py index b715dbb3..e1e4f02c 100644 --- a/dapr/aio/clients/grpc/client.py +++ b/dapr/aio/clients/grpc/client.py @@ -43,13 +43,17 @@ from dapr.clients.grpc._state import StateOptions, StateItem from dapr.clients.grpc._helpers import getWorkflowRuntimeStatus from dapr.clients.health import DaprHealth +from dapr.clients.retry import RetryPolicy from dapr.conf.helpers import GrpcEndpoint from dapr.conf import settings from dapr.proto import api_v1, api_service_v1, common_v1 from dapr.proto.runtime.v1.dapr_pb2 import UnsubscribeConfigurationResponse from dapr.version import __version__ -from dapr.aio.clients.grpc._asynchelpers import DaprClientInterceptorAsync +from dapr.aio.clients.grpc.interceptors import ( + DaprClientInterceptorAsync, + DaprClientTimeoutInterceptorAsync, +) from dapr.clients.grpc._helpers import ( MetadataTuple, to_bytes, @@ -118,6 +122,7 @@ def __init__( ] ] = None, 
max_grpc_message_length: Optional[int] = None, + retry_policy: Optional[RetryPolicy] = None, ): """Connects to Dapr Runtime and initialize gRPC client stub. @@ -131,6 +136,7 @@ def __init__( message length in bytes. """ DaprHealth.wait_until_ready() + self.retry_policy = retry_policy or RetryPolicy() useragent = f'dapr-sdk-python/{__version__}' if not max_grpc_message_length: @@ -154,12 +160,11 @@ def __init__( except ValueError as error: raise DaprInternalError(f'{error}') from error - if self._uri.tls: - self._channel = grpc.aio.secure_channel( - self._uri.endpoint, credentials=self.get_credentials(), options=options - ) # type: ignore + # Prepare interceptors + if interceptors is None: + interceptors = [DaprClientTimeoutInterceptorAsync()] else: - self._channel = grpc.aio.insecure_channel(self._uri.endpoint, options) # type: ignore + interceptors.append(DaprClientTimeoutInterceptorAsync()) if settings.DAPR_API_TOKEN: api_token_interceptor = DaprClientInterceptorAsync( @@ -167,13 +172,20 @@ def __init__( ('dapr-api-token', settings.DAPR_API_TOKEN), ] ) - self._channel = grpc.aio.insecure_channel( # type: ignore - address, options=options, interceptors=(api_token_interceptor,) - ) - if interceptors: - self._channel = grpc.aio.insecure_channel( # type: ignore - address, options=options, *interceptors - ) + interceptors.append(api_token_interceptor) + + # Create gRPC channel + if self._uri.tls: + self._channel = grpc.aio.secure_channel( + self._uri.endpoint, + credentials=self.get_credentials(), + options=options, + interceptors=interceptors, + ) # type: ignore + else: + self._channel = grpc.aio.insecure_channel( + self._uri.endpoint, options, interceptors=interceptors + ) # type: ignore self._stub = api_service_v1.DaprStub(self._channel) @@ -713,8 +725,9 @@ async def save_state( req = api_v1.SaveStateRequest(store_name=store_name, states=[state]) try: - call = self._stub.SaveState(req, metadata=metadata) - await call + result, call = await 
self.retry_policy.run_rpc_async( + self._stub.SaveState, req, metadata=metadata + ) return DaprResponse(headers=await call.initial_metadata()) except AioRpcError as e: raise DaprInternalError(e.details()) from e diff --git a/dapr/aio/clients/grpc/_asynchelpers.py b/dapr/aio/clients/grpc/interceptors.py similarity index 83% rename from dapr/aio/clients/grpc/_asynchelpers.py rename to dapr/aio/clients/grpc/interceptors.py index 32bfe39f..55ede4b9 100644 --- a/dapr/aio/clients/grpc/_asynchelpers.py +++ b/dapr/aio/clients/grpc/interceptors.py @@ -18,6 +18,8 @@ from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails # type: ignore +from dapr.conf import settings + class _ClientCallDetailsAsync( namedtuple( @@ -33,6 +35,22 @@ class _ClientCallDetailsAsync( pass +class DaprClientTimeoutInterceptorAsync(UnaryUnaryClientInterceptor): + def intercept_unary_unary(self, continuation, client_call_details, request): + # If a specific timeout is not set, create a new ClientCallDetails with the default timeout + if client_call_details.timeout is None: + new_client_call_details = _ClientCallDetailsAsync( + client_call_details.method, + settings.DAPR_API_TIMEOUT_SECONDS, + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + ) + return continuation(new_client_call_details, request) + + return continuation(client_call_details, request) + + class DaprClientInterceptorAsync(UnaryUnaryClientInterceptor): """The class implements a UnaryUnaryClientInterceptor from grpc to add an interceptor to add additional headers to all calls as needed. 
diff --git a/dapr/clients/__init__.py b/dapr/clients/__init__.py index b39124b0..20c14785 100644 --- a/dapr/clients/__init__.py +++ b/dapr/clients/__init__.py @@ -21,6 +21,7 @@ from dapr.clients.grpc.client import DaprGrpcClient, MetadataTuple, InvokeMethodResponse from dapr.clients.http.dapr_actor_http_client import DaprActorHttpClient from dapr.clients.http.dapr_invocation_http_client import DaprInvocationHttpClient +from dapr.clients.retry import RetryPolicy from dapr.conf import settings from google.protobuf.message import Message as GrpcMessage @@ -64,6 +65,7 @@ def __init__( ] = None, http_timeout_seconds: Optional[int] = None, max_grpc_message_length: Optional[int] = None, + retry_policy: Optional[RetryPolicy] = None, ): """Connects to Dapr Runtime via gRPC and HTTP. @@ -78,7 +80,7 @@ def __init__( max_grpc_message_length (int, optional): The maximum grpc send and receive message length in bytes. """ - super().__init__(address, interceptors, max_grpc_message_length) + super().__init__(address, interceptors, max_grpc_message_length, retry_policy) self.invocation_client = None invocation_protocol = settings.DAPR_API_METHOD_INVOCATION_PROTOCOL.upper() diff --git a/dapr/clients/grpc/_helpers.py b/dapr/clients/grpc/_helpers.py index 2f9f6006..7da35dc2 100644 --- a/dapr/clients/grpc/_helpers.py +++ b/dapr/clients/grpc/_helpers.py @@ -12,12 +12,10 @@ See the License for the specific language governing permissions and limitations under the License. """ -from collections import namedtuple from typing import Dict, List, Union, Tuple, Optional from enum import Enum from google.protobuf.any_pb2 import Any as GrpcAny from google.protobuf.message import Message as GrpcMessage -from grpc import UnaryUnaryClientInterceptor, ClientCallDetails # type: ignore MetadataDict = Dict[str, List[Union[bytes, str]]] MetadataTuple = Tuple[Tuple[str, Union[bytes, str]], ...] 
@@ -78,96 +76,7 @@ def to_str(data: Union[str, bytes]) -> str: raise f'invalid data type {type(data)}' -class _ClientCallDetails( - namedtuple( - '_ClientCallDetails', - ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'], - ), - ClientCallDetails, -): - """This is an implementation of the ClientCallDetails interface needed for interceptors. - This class takes six named values and inherits the ClientCallDetails from grpc package. - This class encloses the values that describe a RPC to be invoked. - """ - - pass - - -class DaprClientInterceptor(UnaryUnaryClientInterceptor): - """The class implements a UnaryUnaryClientInterceptor from grpc to add an interceptor to add - additional headers to all calls as needed. - - Examples: - - interceptor = HeaderInterceptor([('header', 'value', )]) - intercepted_channel = grpc.intercept_channel(grpc_channel, interceptor) - - With multiple header values: - - interceptor = HeaderInterceptor([('header1', 'value1', ), ('header2', 'value2', )]) - intercepted_channel = grpc.intercept_channel(grpc_channel, interceptor) - """ - - def __init__(self, metadata: List[Tuple[str, str]]): - """Initializes the metadata field for the class. - - Args: - metadata list[tuple[str, str]]: list of tuple of (key, value) strings - representing header values - """ - - self._metadata = metadata - - def _intercept_call(self, client_call_details: ClientCallDetails) -> ClientCallDetails: - """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC - call details. 
- - Args: - client_call_details :class: `ClientCallDetails`: object that describes a RPC - to be invoked - - Returns: - :class: `ClientCallDetails` modified call details - """ - - metadata = [] - if client_call_details.metadata is not None: - metadata = list(client_call_details.metadata) - metadata.extend(self._metadata) - - new_call_details = _ClientCallDetails( - client_call_details.method, - client_call_details.timeout, - metadata, - client_call_details.credentials, - client_call_details.wait_for_ready, - client_call_details.compression, - ) - return new_call_details - - def intercept_unary_unary(self, continuation, client_call_details, request): - """This method intercepts a unary-unary gRPC call. This is the implementation of the - abstract method defined in UnaryUnaryClientInterceptor defined in grpc. This is invoked - automatically by grpc based on the order in which interceptors are added to the channel. - - Args: - continuation: a callable to be invoked to continue with the RPC or next interceptor - client_call_details: a ClientCallDetails object describing the outgoing RPC - request: the request value for the RPC - - Returns: - A response object after invoking the continuation callable - """ - # Pre-process or intercept call - new_call_details = self._intercept_call(client_call_details) - # Call continuation - response = continuation(new_call_details, request) - return response - - # Data validation helpers - - def validateNotNone(**kwargs: Optional[str]): for field_name, value in kwargs.items(): if value is None: diff --git a/dapr/clients/grpc/client.py b/dapr/clients/grpc/client.py index 4b0dc064..53d9016e 100644 --- a/dapr/clients/grpc/client.py +++ b/dapr/clients/grpc/client.py @@ -40,14 +40,15 @@ from dapr.clients.exceptions import DaprInternalError, DaprGrpcError from dapr.clients.grpc._state import StateOptions, StateItem from dapr.clients.grpc._helpers import getWorkflowRuntimeStatus +from dapr.clients.grpc.interceptors import 
DaprClientInterceptor, DaprClientTimeoutInterceptor from dapr.clients.health import DaprHealth +from dapr.clients.retry import RetryPolicy from dapr.conf import settings from dapr.proto import api_v1, api_service_v1, common_v1 from dapr.proto.runtime.v1.dapr_pb2 import UnsubscribeConfigurationResponse from dapr.version import __version__ from dapr.clients.grpc._helpers import ( - DaprClientInterceptor, MetadataTuple, to_bytes, validateNotNone, @@ -116,8 +117,9 @@ def __init__( ] ] = None, max_grpc_message_length: Optional[int] = None, + retry_policy: Optional[RetryPolicy] = None, ): - """Connects to Dapr Runtime and initialize gRPC client stub. + """Connects to Dapr Runtime and initializes gRPC client stub. Args: address (str, optional): Dapr Runtime gRPC endpoint address. @@ -127,8 +129,10 @@ def __init__( StreamStreamClientInterceptor, optional): gRPC interceptors. max_grpc_messsage_length (int, optional): The maximum grpc send and receive message length in bytes. + retry_policy (RetryPolicy optional): Specifies retry behaviour """ DaprHealth.wait_until_ready() + self.retry_policy = retry_policy or RetryPolicy() useragent = f'dapr-sdk-python/{__version__}' if not max_grpc_message_length: @@ -164,6 +168,8 @@ def __init__( options=options, ) + self._channel = grpc.intercept_channel(self._channel, DaprClientTimeoutInterceptor()) # type: ignore + if settings.DAPR_API_TOKEN: api_token_interceptor = DaprClientInterceptor( [ @@ -323,7 +329,9 @@ def invoke_method( ), ) - response, call = self._stub.InvokeService.with_call(req, metadata=metadata, timeout=timeout) + response, call = self.retry_policy.run_rpc( + self._stub.InvokeService.with_call, req, metadata=metadata, timeout=timeout + ) resp_data = InvokeMethodResponse(response.data, response.content_type) resp_data.headers = call.initial_metadata() # type: ignore @@ -384,7 +392,9 @@ def invoke_binding( operation=operation, ) - response, call = self._stub.InvokeBinding.with_call(req, metadata=metadata) + response, call 
= self.retry_policy.run_rpc( + self._stub.InvokeBinding.with_call, req, metadata=metadata + ) return BindingResponse(response.data, dict(response.metadata), call.initial_metadata()) def publish_event( @@ -456,7 +466,9 @@ def publish_event( try: # response is google.protobuf.Empty - _, call = self._stub.PublishEvent.with_call(req, metadata=metadata) + _, call = self.retry_policy.run_rpc( + self._stub.PublishEvent.with_call, req, metadata=metadata + ) except RpcError as err: raise DaprGrpcError(err) from err @@ -503,7 +515,9 @@ def get_state( raise ValueError('State store name cannot be empty') req = api_v1.GetStateRequest(store_name=store_name, key=key, metadata=state_metadata) try: - response, call = self._stub.GetState.with_call(req, metadata=metadata) + response, call = self.retry_policy.run_rpc( + self._stub.GetState.with_call, req, metadata=metadata + ) return StateResponse( data=response.data, etag=response.etag, headers=call.initial_metadata() ) @@ -556,7 +570,9 @@ def get_bulk_state( ) try: - response, call = self._stub.GetBulkState.with_call(req, metadata=metadata) + response, call = self.retry_policy.run_rpc( + self._stub.GetBulkState.with_call, req, metadata=metadata + ) except RpcError as err: raise DaprGrpcError(err) from err @@ -618,7 +634,7 @@ def query_state( req = api_v1.QueryStateRequest(store_name=store_name, query=query, metadata=states_metadata) try: - response, call = self._stub.QueryStateAlpha1.with_call(req) + response, call = self.retry_policy.run_rpc(self._stub.QueryStateAlpha1.with_call, req) except RpcError as err: raise DaprGrpcError(err) from err @@ -710,7 +726,9 @@ def save_state( req = api_v1.SaveStateRequest(store_name=store_name, states=[state]) try: - _, call = self._stub.SaveState.with_call(req, metadata=metadata) + _, call = self.retry_policy.run_rpc( + self._stub.SaveState.with_call, req, metadata=metadata + ) return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprGrpcError(err) from err @@ -771,7 
+789,9 @@ def save_bulk_state( req = api_v1.SaveStateRequest(store_name=store_name, states=req_states) try: - _, call = self._stub.SaveState.with_call(req, metadata=metadata) + _, call = self.retry_policy.run_rpc( + self._stub.SaveState.with_call, req, metadata=metadata + ) return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprGrpcError(err) from err @@ -840,7 +860,9 @@ def execute_state_transaction( ) try: - _, call = self._stub.ExecuteStateTransaction.with_call(req, metadata=metadata) + _, call = self.retry_policy.run_rpc( + self._stub.ExecuteStateTransaction.with_call, req, metadata=metadata + ) return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprGrpcError(err) from err @@ -908,7 +930,9 @@ def delete_state( ) try: - _, call = self._stub.DeleteState.with_call(req, metadata=metadata) + _, call = self.retry_policy.run_rpc( + self._stub.DeleteState.with_call, req, metadata=metadata + ) return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprGrpcError(err) from err @@ -960,7 +984,9 @@ def get_secret( req = api_v1.GetSecretRequest(store_name=store_name, key=key, metadata=secret_metadata) - response, call = self._stub.GetSecret.with_call(req, metadata=metadata) + response, call = self.retry_policy.run_rpc( + self._stub.GetSecret.with_call, req, metadata=metadata + ) return GetSecretResponse(secret=response.data, headers=call.initial_metadata()) @@ -1007,7 +1033,9 @@ def get_bulk_secret( req = api_v1.GetBulkSecretRequest(store_name=store_name, metadata=secret_metadata) - response, call = self._stub.GetBulkSecret.with_call(req, metadata=metadata) + response, call = self.retry_policy.run_rpc( + self._stub.GetBulkSecret.with_call, req, metadata=metadata + ) secrets_map = {} for key in response.data.keys(): @@ -1047,7 +1075,7 @@ def get_configuration( req = api_v1.GetConfigurationRequest( store_name=store_name, keys=keys, metadata=config_metadata ) - response, call = 
self._stub.GetConfiguration.with_call(req) + response, call = self.retry_policy.run_rpc(self._stub.GetConfiguration.with_call, req) return ConfigurationResponse(items=response.items, headers=call.initial_metadata()) def subscribe_configuration( @@ -1155,7 +1183,7 @@ def try_lock( lock_owner=lock_owner, expiry_in_seconds=expiry_in_seconds, ) - response, call = self._stub.TryLockAlpha1.with_call(req) + response, call = self.retry_policy.run_rpc(self._stub.TryLockAlpha1.with_call, req) return TryLockResponse( success=response.success, client=self, @@ -1193,7 +1221,7 @@ def unlock(self, store_name: str, resource_id: str, lock_owner: str) -> UnlockRe req = api_v1.UnlockRequest( store_name=store_name, resource_id=resource_id, lock_owner=lock_owner ) - response, call = self._stub.UnlockAlpha1.with_call(req) + response, call = self.retry_policy.run_rpc(self._stub.UnlockAlpha1.with_call, req) return UnlockResponse( status=UnlockResponseStatus(response.status), headers=call.initial_metadata() @@ -1289,7 +1317,7 @@ def get_workflow(self, instance_id: str, workflow_component: str) -> GetWorkflow ) try: - resp = self._stub.GetWorkflowBeta1(req) + resp = self.retry_policy.run_rpc(self._stub.GetWorkflowBeta1, req) if resp.created_at is None: resp.created_at = datetime.now() if resp.last_updated_at is None: @@ -1331,7 +1359,7 @@ def terminate_workflow(self, instance_id: str, workflow_component: str) -> DaprR ) try: - _, call = self._stub.TerminateWorkflowBeta1.with_call(req) + _, call = self.retry_policy.run_rpc(self._stub.TerminateWorkflowBeta1.with_call, req) return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise DaprInternalError(err.details()) @@ -1402,7 +1430,7 @@ def raise_workflow_event( ) try: - _, call = self._stub.RaiseEventWorkflowBeta1.with_call(req) + _, call = self.retry_policy.run_rpc(self._stub.RaiseEventWorkflowBeta1.with_call, req) return DaprResponse(headers=call.initial_metadata()) except RpcError as err: raise 
DaprInternalError(err.details()) @@ -1433,7 +1461,7 @@ def pause_workflow(self, instance_id: str, workflow_component: str) -> DaprRespo ) try: - _, call = self._stub.PauseWorkflowBeta1.with_call(req) + _, call = self.retry_policy.run_rpc(self._stub.PauseWorkflowBeta1.with_call, req) return DaprResponse(headers=call.initial_metadata()) except RpcError as err: @@ -1464,7 +1492,7 @@ def resume_workflow(self, instance_id: str, workflow_component: str) -> DaprResp ) try: - _, call = self._stub.ResumeWorkflowBeta1.with_call(req) + _, call = self.retry_policy.run_rpc(self._stub.ResumeWorkflowBeta1.with_call, req) return DaprResponse(headers=call.initial_metadata()) except RpcError as err: @@ -1495,7 +1523,7 @@ def purge_workflow(self, instance_id: str, workflow_component: str) -> DaprRespo ) try: - _, call = self._stub.PurgeWorkflowBeta1.with_call(req) + _, call = self.retry_policy.run_rpc(self._stub.PurgeWorkflowBeta1.with_call, req) return DaprResponse(headers=call.initial_metadata()) @@ -1551,7 +1579,7 @@ def get_metadata(self) -> GetMetadataResponse: capabilities. 
""" try: - _resp, call = self._stub.GetMetadata.with_call(GrpcEmpty()) + _resp, call = self.retry_policy.run_rpc(self._stub.GetMetadata.with_call, GrpcEmpty()) except RpcError as err: raise DaprGrpcError(err) from err @@ -1597,7 +1625,7 @@ def set_metadata(self, attributeName: str, attributeValue: str) -> DaprResponse: validateNotNone(attributeValue=attributeValue) # Actual invocation req = api_v1.SetMetadataRequest(key=attributeName, value=attributeValue) - _, call = self._stub.SetMetadata.with_call(req) + _, call = self.retry_policy.run_rpc(self._stub.SetMetadata.with_call, req) return DaprResponse(call.initial_metadata()) @@ -1617,6 +1645,6 @@ def shutdown(self) -> DaprResponse: :class:`DaprResponse` gRPC metadata returned from callee """ - _, call = self._stub.Shutdown.with_call(GrpcEmpty()) + _, call = self.retry_policy.run_rpc(self._stub.Shutdown.with_call, GrpcEmpty()) return DaprResponse(call.initial_metadata()) diff --git a/dapr/clients/grpc/interceptors.py b/dapr/clients/grpc/interceptors.py new file mode 100644 index 00000000..22098f53 --- /dev/null +++ b/dapr/clients/grpc/interceptors.py @@ -0,0 +1,110 @@ +from collections import namedtuple +from typing import List, Tuple + +from grpc import UnaryUnaryClientInterceptor, ClientCallDetails # type: ignore + +from dapr.conf import settings + + +class _ClientCallDetails( + namedtuple( + '_ClientCallDetails', + ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'], + ), + ClientCallDetails, +): + """This is an implementation of the ClientCallDetails interface needed for interceptors. + This class takes six named values and inherits the ClientCallDetails from grpc package. + This class encloses the values that describe a RPC to be invoked. 
+ """ + + pass + + +class DaprClientTimeoutInterceptor(UnaryUnaryClientInterceptor): + def intercept_unary_unary(self, continuation, client_call_details, request): + # If a specific timeout is not set, create a new ClientCallDetails with the default timeout + if client_call_details.timeout is None: + new_client_call_details = _ClientCallDetails( + client_call_details.method, + settings.DAPR_API_TIMEOUT_SECONDS, + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + client_call_details.compression, + ) + return continuation(new_client_call_details, request) + + return continuation(client_call_details, request) + + +class DaprClientInterceptor(UnaryUnaryClientInterceptor): + """The class implements a UnaryUnaryClientInterceptor from grpc to add an interceptor to add + additional headers to all calls as needed. + + Examples: + + interceptor = HeaderInterceptor([('header', 'value', )]) + intercepted_channel = grpc.intercept_channel(grpc_channel, interceptor) + + With multiple header values: + + interceptor = HeaderInterceptor([('header1', 'value1', ), ('header2', 'value2', )]) + intercepted_channel = grpc.intercept_channel(grpc_channel, interceptor) + """ + + def __init__(self, metadata: List[Tuple[str, str]]): + """Initializes the metadata field for the class. + + Args: + metadata list[tuple[str, str]]: list of tuple of (key, value) strings + representing header values + """ + + self._metadata = metadata + + def _intercept_call(self, client_call_details: ClientCallDetails) -> ClientCallDetails: + """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC + call details. 
+ + Args: + client_call_details :class: `ClientCallDetails`: object that describes a RPC + to be invoked + + Returns: + :class: `ClientCallDetails` modified call details + """ + + metadata = [] + if client_call_details.metadata is not None: + metadata = list(client_call_details.metadata) + metadata.extend(self._metadata) + + new_call_details = _ClientCallDetails( + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + client_call_details.compression, + ) + return new_call_details + + def intercept_unary_unary(self, continuation, client_call_details, request): + """This method intercepts a unary-unary gRPC call. This is the implementation of the + abstract method defined in UnaryUnaryClientInterceptor defined in grpc. This is invoked + automatically by grpc based on the order in which interceptors are added to the channel. + + Args: + continuation: a callable to be invoked to continue with the RPC or next interceptor + client_call_details: a ClientCallDetails object describing the outgoing RPC + request: the request value for the RPC + + Returns: + A response object after invoking the continuation callable + """ + # Pre-process or intercept call + new_call_details = self._intercept_call(client_call_details) + # Call continuation + response = continuation(new_call_details, request) + return response diff --git a/dapr/clients/http/client.py b/dapr/clients/http/client.py index 0d591156..6f2a8e3d 100644 --- a/dapr/clients/http/client.py +++ b/dapr/clients/http/client.py @@ -17,13 +17,14 @@ from typing import Callable, Mapping, Dict, Optional, Union, Tuple, TYPE_CHECKING +from dapr.clients.health import DaprHealth from dapr.clients.http.conf import ( DAPR_API_TOKEN_HEADER, USER_AGENT_HEADER, DAPR_USER_AGENT, CONTENT_TYPE_HEADER, ) -from dapr.clients.health import DaprHealth +from dapr.clients.retry import RetryPolicy if TYPE_CHECKING: from dapr.serializers import Serializer @@ 
-41,6 +42,7 @@ def __init__( message_serializer: 'Serializer', timeout: Optional[int] = 60, headers_callback: Optional[Callable[[], Dict[str, str]]] = None, + retry_policy: Optional[RetryPolicy] = None, ): """Invokes Dapr over HTTP. @@ -54,6 +56,7 @@ def __init__( self._timeout = aiohttp.ClientTimeout(total=timeout) self._serializer = message_serializer self._headers_callback = headers_callback + self.retry_policy = retry_policy or RetryPolicy() async def send_bytes( self, @@ -81,17 +84,19 @@ async def send_bytes( client_timeout = aiohttp.ClientTimeout(total=timeout) if timeout else self._timeout sslcontext = self.get_ssl_context() - async with aiohttp.ClientSession(timeout=client_timeout) as session: - r = await session.request( - method=method, - url=url, - data=data, - headers=headers_map, - ssl=sslcontext, - params=query_params, - ) - - if r.status >= 200 and r.status < 300: + async with aiohttp.ClientSession() as session: + req = { + 'method': method, + 'url': url, + 'data': data, + 'headers': headers_map, + 'sslcontext': sslcontext, + 'params': query_params, + 'timeout': client_timeout, + } + r = await self.retry_policy.make_http_call(session, req) + + if 200 <= r.status < 300: return await r.read(), r raise (await self.convert_to_error(r)) diff --git a/dapr/clients/http/dapr_actor_http_client.py b/dapr/clients/http/dapr_actor_http_client.py index a4fccfc1..186fdbc1 100644 --- a/dapr/clients/http/dapr_actor_http_client.py +++ b/dapr/clients/http/dapr_actor_http_client.py @@ -22,6 +22,7 @@ from dapr.clients.http.client import DaprHttpClient from dapr.clients.base import DaprActorClientBase +from dapr.clients.retry import RetryPolicy DAPR_REENTRANCY_ID_HEADER = 'Dapr-Reentrancy-Id' @@ -34,6 +35,7 @@ def __init__( message_serializer: 'Serializer', timeout: int = 60, headers_callback: Optional[Callable[[], Dict[str, str]]] = None, + retry_policy: Optional[RetryPolicy] = None, ): """Invokes Dapr Actors over HTTP. 
@@ -41,8 +43,9 @@ def __init__( message_serializer (Serializer): Dapr serializer. timeout (int, optional): Timeout in seconds, defaults to 60. headers_callback (lambda: Dict[str, str]], optional): Generates header for each request. + retry_policy (RetryPolicy optional): Specifies retry behaviour """ - self._client = DaprHttpClient(message_serializer, timeout, headers_callback) + self._client = DaprHttpClient(message_serializer, timeout, headers_callback, retry_policy) async def invoke_method( self, actor_type: str, actor_id: str, method: str, data: Optional[bytes] = None diff --git a/dapr/clients/http/dapr_invocation_http_client.py b/dapr/clients/http/dapr_invocation_http_client.py index ca1a5dfa..df4e6d22 100644 --- a/dapr/clients/http/dapr_invocation_http_client.py +++ b/dapr/clients/http/dapr_invocation_http_client.py @@ -23,6 +23,7 @@ from dapr.clients.grpc._response import InvokeMethodResponse from dapr.clients.http.conf import CONTENT_TYPE_HEADER from dapr.clients.http.helpers import get_api_url +from dapr.clients.retry import RetryPolicy from dapr.serializers import DefaultJSONSerializer from dapr.version import __version__ @@ -34,15 +35,21 @@ class DaprInvocationHttpClient: """Service Invocation HTTP Client""" def __init__( - self, timeout: int = 60, headers_callback: Optional[Callable[[], Dict[str, str]]] = None + self, + timeout: int = 60, + headers_callback: Optional[Callable[[], Dict[str, str]]] = None, + retry_policy: Optional[RetryPolicy] = None, ): """Invokes Dapr's API for method invocation over HTTP. Args: timeout (int, optional): Timeout in seconds, defaults to 60. headers_callback (lambda: Dict[str, str]], optional): Generates header for each request. 
+ retry_policy (RetryPolicy optional): Specifies retry behaviour """ - self._client = DaprHttpClient(DefaultJSONSerializer(), timeout, headers_callback) + self._client = DaprHttpClient( + DefaultJSONSerializer(), timeout, headers_callback, retry_policy=retry_policy + ) async def invoke_method_async( self, diff --git a/dapr/clients/retry.py b/dapr/clients/retry.py new file mode 100644 index 00000000..171c96fb --- /dev/null +++ b/dapr/clients/retry.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import asyncio +from typing import Optional, List, Callable + +from grpc import RpcError, StatusCode # type: ignore +import time + +from dapr.conf import settings + + +class RetryPolicy: + """RetryPolicy holds the retry policy configuration for a gRPC client. + + Args: + max_attempts (int): The maximum number of retry attempts. + initial_backoff (int): The initial backoff duration. + max_backoff (int): The maximum backoff duration. + backoff_multiplier (float): The backoff multiplier. 
+        retryable_http_status_codes (List[int]): The list of http retryable status codes
+        retryable_grpc_status_codes (List[StatusCode]): The list of retryable grpc status codes
+    """
+
+    def __init__(
+        self,
+        max_attempts: Optional[int] = settings.DAPR_API_MAX_RETRIES,
+        initial_backoff: int = 1,
+        max_backoff: int = 20,
+        backoff_multiplier: float = 1.5,
+        retryable_http_status_codes: List[int] = [408, 429, 500, 502, 503, 504],
+        retryable_grpc_status_codes: List[StatusCode] = [
+            StatusCode.UNAVAILABLE,
+            StatusCode.DEADLINE_EXCEEDED,
+        ],
+    ):
+        if max_attempts < -1:  # type: ignore
+            raise ValueError('max_attempts must be greater than or equal to -1')
+        self.max_attempts = max_attempts
+
+        if initial_backoff < 1:
+            raise ValueError('initial_backoff must be greater than or equal to 1')
+        self.initial_backoff = initial_backoff
+
+        if max_backoff < 1:
+            raise ValueError('max_backoff must be greater than or equal to 1')
+        self.max_backoff = max_backoff
+
+        if backoff_multiplier < 1:
+            raise ValueError('backoff_multiplier must be greater than or equal to 1')
+        self.backoff_multiplier = backoff_multiplier
+
+        if len(retryable_http_status_codes) == 0:
+            raise ValueError("retryable_http_status_codes can't be empty")
+        self.retryable_http_status_codes = retryable_http_status_codes
+
+        if len(retryable_grpc_status_codes) == 0:
+            raise ValueError("retryable_grpc_status_codes can't be empty")
+        self.retryable_grpc_status_codes = retryable_grpc_status_codes
+
+    def run_rpc(self, func: Callable, *args, **kwargs):
+        # If max_retries is 0, we don't retry
+        if self.max_attempts == 0:
+            return func(*args, **kwargs)
+
+        attempt = 0
+        while self.max_attempts == -1 or attempt < self.max_attempts:  # type: ignore
+            try:
+                print(f'Trying RPC call, attempt {attempt + 1}')
+                return func(*args, **kwargs)
+            except RpcError as err:
+                if err.code() not in self.retryable_grpc_status_codes:
+                    raise
+                if self.max_attempts != -1 and attempt == self.max_attempts - 1:  # type: ignore
+                    raise
+                
sleep_time = min( + self.max_backoff, + self.initial_backoff * (self.backoff_multiplier**attempt), + ) + print(f'Sleeping for {sleep_time} seconds before retrying RPC call') + time.sleep(sleep_time) + attempt += 1 + raise Exception(f'RPC call failed after {attempt} retries') + + async def run_rpc_async(self, func: Callable, *args, **kwargs): + # If max_retries is 0, we don't retry + if self.max_attempts == 0: + call = func(*args, **kwargs) + result = await call + return result, call + + attempt = 0 + while self.max_attempts == -1 or attempt < self.max_attempts: # type: ignore + try: + print(f'Trying RPC call, attempt {attempt + 1}') + call = func(*args, **kwargs) + result = await call + return result, call + except RpcError as err: + if err.code() not in self.retryable_grpc_status_codes: + raise + if self.max_attempts != -1 and attempt == self.max_attempts - 1: # type: ignore + raise + sleep_time = min( + self.max_backoff, + self.initial_backoff * (self.backoff_multiplier**attempt), + ) + print(f'Sleeping for {sleep_time} seconds before retrying RPC call') + await asyncio.sleep(sleep_time) + attempt += 1 + raise Exception(f'RPC call failed after {attempt} retries') + + async def make_http_call(self, session, req): + # If max_retries is 0, we don't retry + if self.max_attempts == 0: + return await session.request( + method=req['method'], + url=req['url'], + data=req['data'], + headers=req['headers'], + ssl=req['sslcontext'], + params=req['params'], + timeout=req['timeout'], + ) + + attempt = 0 + while self.max_attempts == -1 or attempt < self.max_attempts: # type: ignore + print(f'Request attempt {attempt + 1}') + r = await session.request( + method=req['method'], + url=req['url'], + data=req['data'], + headers=req['headers'], + ssl=req['sslcontext'], + params=req['params'], + timeout=req['timeout'], + ) + + if r.status not in self.retryable_http_status_codes: + return r + + if ( + self.max_attempts != -1 and attempt == self.max_attempts - 1 # type: ignore + ): # 
type: ignore + return r + + sleep_time = min( + self.max_backoff, + self.initial_backoff * (self.backoff_multiplier**attempt), + ) + + print(f'Sleeping for {sleep_time} seconds before retrying call') + await asyncio.sleep(sleep_time) + attempt += 1 diff --git a/dapr/conf/__init__.py b/dapr/conf/__init__.py index 1a25fc10..7fbe5f2f 100644 --- a/dapr/conf/__init__.py +++ b/dapr/conf/__init__.py @@ -23,7 +23,13 @@ def __init__(self): for setting in dir(global_settings): default_value = getattr(global_settings, setting) env_variable = os.environ.get(setting) - setattr(self, setting, env_variable or default_value) + if env_variable: + val = ( + type(default_value)(env_variable) if default_value is not None else env_variable + ) + setattr(self, setting, val) + else: + setattr(self, setting, default_value) def __getattr__(self, name): if name not in dir(global_settings): diff --git a/dapr/conf/global_settings.py b/dapr/conf/global_settings.py index b7cb885b..43bb51f6 100644 --- a/dapr/conf/global_settings.py +++ b/dapr/conf/global_settings.py @@ -27,6 +27,9 @@ DAPR_API_VERSION = 'v1.0' DAPR_HEALTH_TIMEOUT = 60 # seconds +DAPR_API_MAX_RETRIES = 0 +DAPR_API_TIMEOUT_SECONDS = 60 + DAPR_API_METHOD_INVOCATION_PROTOCOL = 'http' DAPR_HTTP_TIMEOUT_SECONDS = 60 diff --git a/daprdocs/content/en/python-sdk-docs/python-client.md b/daprdocs/content/en/python-sdk-docs/python-client.md index 3030f64a..54c8cc30 100644 --- a/daprdocs/content/en/python-sdk-docs/python-client.md +++ b/daprdocs/content/en/python-sdk-docs/python-client.md @@ -49,7 +49,7 @@ with DaprClient("mydomain:50051?tls=true") as d: # use the client ``` -#### Environment variables: +#### Configuration options: ##### Dapr Sidecar Endpoints You can use the standardised `DAPR_GRPC_ENDPOINT` environment variable to @@ -75,12 +75,47 @@ set it in the environment and the client will use it automatically. You can read more about Dapr API token authentication [here](https://docs.dapr.io/operations/security/api-token/). 
##### Health timeout -On client initialisation, a health check is performed against the Dapr sidecar (`/healthz/outboud`). +On client initialisation, a health check is performed against the Dapr sidecar (`/healthz/outbound`). The client will wait for the sidecar to be up and running before proceeding. -The default timeout is 60 seconds, but it can be overridden by setting the `DAPR_HEALTH_TIMEOUT` +The default healthcheck timeout is 60 seconds, but it can be overridden by setting the `DAPR_HEALTH_TIMEOUT` environment variable. +##### Retries and timeout + +The Dapr client can retry a request if a specific error code is received from the sidecar. This is +configurable through the `DAPR_API_MAX_RETRIES` environment variable and is picked up automatically, +not requiring any code changes. +The default value for `DAPR_API_MAX_RETRIES` is `0`, which means no retries will be made. + +You can fine-tune more retry parameters by creating a `dapr.clients.retry.RetryPolicy` object and +passing it to the DaprClient constructor: + +```python +from dapr.clients.retry import RetryPolicy + +retry = RetryPolicy( + max_attempts=5, + initial_backoff=1, + max_backoff=20, + backoff_multiplier=1.5, + retryable_http_status_codes=[408, 429, 500, 502, 503, 504], + retryable_grpc_status_codes=[StatusCode.UNAVAILABLE, StatusCode.DEADLINE_EXCEEDED, ] +) + +with DaprClient(retry_policy=retry) as d: + ... +``` + +or for actors: +```python +factory = ActorProxyFactory(retry_policy=RetryPolicy(max_attempts=3)) +proxy = ActorProxy.create('DemoActor', ActorId('1'), DemoActorInterface, factory) +``` + +**Timeout** can be set for all calls through the environment variable `DAPR_API_TIMEOUT_SECONDS`. The default value is 60 seconds. + +> Note: You can control timeouts on service invocation separately, by passing a `timeout` parameter to the `invoke_method` method. 
## Error handling Initially, errors in Dapr followed the [Standard gRPC error model](https://grpc.io/docs/guides/error/#standard-error-model). However, to provide more detailed and informative error messages, in version 1.13 an enhanced error model has been introduced which aligns with the gRPC [Richer error model](https://grpc.io/docs/guides/error/#richer-error-model). In response, the Python SDK implemented `DaprGrpcError`, a custom exception class designed to improve the developer experience. diff --git a/examples/demo_actor/demo_actor/demo_actor_client.py b/examples/demo_actor/demo_actor/demo_actor_client.py index 67912aeb..df0e9f73 100644 --- a/examples/demo_actor/demo_actor/demo_actor_client.py +++ b/examples/demo_actor/demo_actor/demo_actor_client.py @@ -12,13 +12,15 @@ import asyncio -from dapr.actor import ActorProxy, ActorId +from dapr.actor import ActorProxy, ActorId, ActorProxyFactory +from dapr.clients.retry import RetryPolicy from demo_actor_interface import DemoActorInterface async def main(): # Create proxy client - proxy = ActorProxy.create('DemoActor', ActorId('1'), DemoActorInterface) + factory = ActorProxyFactory(retry_policy=RetryPolicy(max_attempts=3)) + proxy = ActorProxy.create('DemoActor', ActorId('1'), DemoActorInterface, factory) # ----------------------------------------------- # Actor invocation demo diff --git a/examples/error_handling/README.md b/examples/error_handling/README.md index cca285f5..1d24bc82 100644 --- a/examples/error_handling/README.md +++ b/examples/error_handling/README.md @@ -1,6 +1,6 @@ # Example - Error handling -This guide demonstrates handling `DaprGrpcError` errors when using the Dapr python-SDK. It's important to note that not all Dapr gRPC status errors are currently captured and transformed into a `DaprGrpcError` by the SDK. Efforts are ongoing to enhance this aspect, and contributions are welcome. 
For detailed information on error handling in Dapr, refer to the [official documentation](https://docs.dapr.io/reference/). +This guide demonstrates handling `DaprGrpcError` errors when using the Dapr python-SDK. It's important to note that not all Dapr gRPC status errors are currently captured and transformed into a `DaprGrpcError` by the SDK. Efforts are ongoing to enhance this aspect, and contributions are welcome. For detailed information on error handling in Dapr, refer to the [official documentation](https://docs.dapr.io/reference/errors). The example involves creating a DaprClient and invoking the save_state method. It uses the default configuration from Dapr init in [self-hosted mode](https://github.com/dapr/cli#install-dapr-on-your-local-machine-self-hosted). diff --git a/tests/clients/test_client_interceptor.py b/tests/clients/test_client_interceptor.py index da654151..52ca9618 100644 --- a/tests/clients/test_client_interceptor.py +++ b/tests/clients/test_client_interceptor.py @@ -15,7 +15,7 @@ import unittest -from dapr.clients.grpc._helpers import DaprClientInterceptor, _ClientCallDetails +from dapr.clients.grpc.interceptors import DaprClientInterceptor, _ClientCallDetails class DaprClientInterceptorTests(unittest.TestCase): diff --git a/tests/clients/test_retries_policy.py b/tests/clients/test_retries_policy.py new file mode 100644 index 00000000..b5137e64 --- /dev/null +++ b/tests/clients/test_retries_policy.py @@ -0,0 +1,332 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +from unittest import mock +from unittest.mock import Mock, MagicMock, patch, AsyncMock + +from grpc import StatusCode, RpcError + +from dapr.clients.retry import RetryPolicy +from dapr.serializers import DefaultJSONSerializer + + +class RetryPolicyTests(unittest.TestCase): + async def httpSetUp(self): + # Setup your test environment and mocks here + self.session = MagicMock() + self.session.request = AsyncMock() + + self.serializer = (DefaultJSONSerializer(),) + + # Example request + self.req = { + 'method': 'GET', + 'url': 'http://example.com', + 'data': None, + 'headers': None, + 'sslcontext': None, + 'params': None, + 'timeout': None, + } + + def test_init_success_default(self): + policy = RetryPolicy() + + self.assertEqual(0, policy.max_attempts) + self.assertEqual(1, policy.initial_backoff) + self.assertEqual(20, policy.max_backoff) + self.assertEqual(1.5, policy.backoff_multiplier) + self.assertEqual([408, 429, 500, 502, 503, 504], policy.retryable_http_status_codes) + self.assertEqual( + [StatusCode.UNAVAILABLE, StatusCode.DEADLINE_EXCEEDED], + policy.retryable_grpc_status_codes, + ) + + def test_init_success(self): + policy = RetryPolicy( + max_attempts=3, + initial_backoff=2, + max_backoff=10, + backoff_multiplier=2, + retryable_grpc_status_codes=[StatusCode.UNAVAILABLE], + retryable_http_status_codes=[408, 429], + ) + self.assertEqual(3, policy.max_attempts) + self.assertEqual(2, policy.initial_backoff) + self.assertEqual(10, policy.max_backoff) + self.assertEqual(2, policy.backoff_multiplier) + self.assertEqual([StatusCode.UNAVAILABLE], policy.retryable_grpc_status_codes) + self.assertEqual([408, 429], policy.retryable_http_status_codes) + + def test_init_with_errors(self): + with self.assertRaises(ValueError): + RetryPolicy(max_attempts=-2) + + with self.assertRaises(ValueError): + RetryPolicy(initial_backoff=0) + + with 
self.assertRaises(ValueError): + RetryPolicy(max_backoff=0) + + with self.assertRaises(ValueError): + RetryPolicy(backoff_multiplier=0) + + with self.assertRaises(ValueError): + RetryPolicy(retryable_http_status_codes=[]) + + with self.assertRaises(ValueError): + RetryPolicy(retryable_grpc_status_codes=[]) + + def test_run_rpc_with_retry_success(self): + mock_func = Mock(return_value='success') + + policy = RetryPolicy(max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE]) + result = policy.run_rpc(mock_func, 'foo', 'bar', arg1=1, arg2=2) + + self.assertEqual(result, 'success') + mock_func.assert_called_once_with('foo', 'bar', arg1=1, arg2=2) + + def test_run_rpc_with_retry_no_retry(self): + mock_error = RpcError() + mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE) + mock_func = MagicMock(side_effect=mock_error) + + policy = RetryPolicy(max_attempts=0) + with self.assertRaises(RpcError): + policy.run_rpc(mock_func) + mock_func.assert_called_once() + + @patch('time.sleep', return_value=None) # To speed up tests + def test_run_rpc_with_retry_fail(self, mock_sleep): + mock_error = RpcError() + mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE) + mock_func = MagicMock(side_effect=mock_error) + with self.assertRaises(RpcError): + policy = RetryPolicy(max_attempts=4, initial_backoff=2, backoff_multiplier=1.5) + policy.run_rpc(mock_func) + + self.assertEqual(mock_func.call_count, 4) + expected_sleep_calls = [ + mock.call(2.0), # First sleep call + mock.call(3.0), # Second sleep call + mock.call(4.5), # Third sleep call + ] + mock_sleep.assert_has_calls(expected_sleep_calls, any_order=False) + + def test_run_rpc_with_retry_fail_with_another_status_code(self): + mock_error = RpcError() + mock_error.code = MagicMock(return_value=StatusCode.FAILED_PRECONDITION) + mock_func = MagicMock(side_effect=mock_error) + + with self.assertRaises(RpcError): + policy = RetryPolicy( + max_attempts=3, 
retryable_grpc_status_codes=[StatusCode.UNAVAILABLE] + ) + policy.run_rpc(mock_func) + + mock_func.assert_called_once() + + @patch('time.sleep', return_value=None) # To speed up tests + def test_run_rpc_with_retry_fail_with_max_backoff(self, mock_sleep): + mock_error = RpcError() + mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE) + mock_func = MagicMock(side_effect=mock_error) + with self.assertRaises(RpcError): + policy = RetryPolicy( + max_attempts=4, initial_backoff=2, backoff_multiplier=1.5, max_backoff=3 + ) + policy.run_rpc( + mock_func, + ) + + self.assertEqual(mock_func.call_count, 4) + expected_sleep_calls = [ + mock.call(2.0), # First sleep call + mock.call(3.0), # Second sleep call + mock.call(3.0), # Third sleep call + ] + mock_sleep.assert_has_calls(expected_sleep_calls, any_order=False) + + @patch('time.sleep', return_value=None) # To speed up tests + def test_run_rpc_with_infinite_retries(self, mock_sleep): + # Testing a function that's supposed to run forever is tricky, so we'll simulate it + # Instead of a fixed side effect, we'll create a function that's supposed to + # break out of the cycle after X calls. 
+        # Then we assert that the function was called X times before breaking the loop
+
+        # Configure the policy to simulate infinite retries
+        policy = RetryPolicy(max_attempts=-1, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
+
+        mock_error = RpcError()
+        mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE)
+        mock_func = MagicMock()
+
+        # Use a side effect on the mock to count calls and eventually interrupt the loop
+        call_count = 0
+
+        def side_effect(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+
+            if call_count >= 10:  # Ten calls before breaking the loop
+                raise Exception('Test interrupt')
+
+            raise mock_error
+
+        mock_func.side_effect = side_effect
+
+        # Run the test, expecting the custom exception to break the loop
+        with self.assertRaises(Exception) as context:
+            policy.run_rpc(mock_func)
+
+        self.assertEqual(str(context.exception), 'Test interrupt')
+
+        # Verify the function was retried the expected number of times before interrupting
+        self.assertEqual(call_count, 10)
+
+    # Test retrying async rpc calls
+    async def test_run_rpc_async_with_retry_success(self):
+        mock_func = AsyncMock(return_value='success')
+
+        policy = RetryPolicy(max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE])
+        result, _ = await policy.run_rpc_async(mock_func, 'foo', arg1=1, arg2=2)
+
+        self.assertEqual(result, 'success')
+        mock_func.assert_awaited_once_with('foo', arg1=1, arg2=2)
+
+    async def test_run_rpc_async_with_retry_no_retry(self):
+        mock_error = RpcError()
+        mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE)
+        mock_func = AsyncMock(side_effect=mock_error)
+
+        with self.assertRaises(RpcError):
+            policy = RetryPolicy(max_attempts=0)
+            await policy.run_rpc_async(mock_func)
+        mock_func.assert_awaited_once()
+
+    # Test retrying http requests
+    async def test_http_call_with_success(self):
+        # Mock the request to succeed on the first try
+        self.session.request.return_value.status = 200
+
+        policy = RetryPolicy()
+        
response = await policy.make_http_call(self.session, self.req) + + self.session.request.assert_called_once() + self.assertEqual(200, response.status) + + async def test_http_call_success_with_no_retry(self): + self.session.request.return_value.status = 200 + + policy = RetryPolicy(max_attempts=0) + response = await policy.make_http_call(self.session, self.req) + + self.session.request.assert_called_once() + self.assertEqual(200, response.status) + + async def test_http_call_fail_with_no_retry(self): + self.session.request.return_value.status = 408 + + policy = RetryPolicy(max_attempts=0) + response = await policy.make_http_call(self.session, self.req) + + self.session.request.assert_called_once() + self.assertEqual(408, response.status) + + @patch('asyncio.sleep', return_value=None) + async def test_http_call_retry_eventually_succeeds(self, _): + # Mock the request to fail twice then succeed + self.session.request.side_effect = [ + MagicMock(status=500), # First attempt fails + MagicMock(status=502), # Second attempt fails + MagicMock(status=200), # Third attempt succeeds + ] + + policy = RetryPolicy(max_attempts=3) + response = await policy.make_http_call(self.session, self.req) + + self.assertEqual(3, self.session.request.call_count) + self.assertEqual(200, response.status) + + @patch('asyncio.sleep', return_value=None) + async def test_http_call_retry_eventually_fails(self, _): + self.session.request.return_value.status = 408 + + policy = RetryPolicy(max_attempts=3) + response = await policy.make_http_call(self.session, self.req) + + self.assertEqual(3, self.session.request.call_count) + self.assertEqual(408, response.status) + + @patch('asyncio.sleep', return_value=None) + async def test_http_call_retry_fails_with_a_different_code(self, _): + # Mock the request to fail twice then succeed + self.session.request.return_value.status = 501 + + policy = RetryPolicy(max_attempts=3, retryable_http_status_codes=[500]) + response = await 
policy.make_http_call(self.session, self.req) + + self.session.request.assert_called_once() + self.assertEqual(response.status, 501) + + @patch('asyncio.sleep', return_value=None) + async def test_http_call_retries_exhausted(self, _): + # Mock the request to fail three times + self.session.request.return_value = MagicMock(status=500) + + policy = RetryPolicy(max_attempts=3, retryable_http_status_codes=[500]) + response = await policy.make_http_call(self.session, self.req) + + self.assertEqual(3, self.session.request.call_count) + self.assertEqual(500, response.status) + + @patch('asyncio.sleep', return_value=None) + async def test_http_call_max_backoff(self, mock_sleep): + self.session.request.return_value.status = 500 + + policy = RetryPolicy(max_attempts=4, initial_backoff=2, backoff_multiplier=2, max_backoff=3) + response = await policy.make_http_call(self.session, self.req) + + expected_sleep_calls = [ + mock.call(2.0), # First sleep call + mock.call(3.0), # Second sleep call + mock.call(3.0), # Third sleep call + ] + self.assertEqual(4, self.session.request.call_count) + mock_sleep.assert_has_calls(expected_sleep_calls, any_order=False) + self.assertEqual(500, response.status) + + @patch('asyncio.sleep', return_value=None) + async def test_http_call_infinite_retries(self, _): + retry_count = 0 + max_test_retries = 6 # Simulates "indefinite" retries for test purposes + + # Function to simulate request behavior + async def mock_request(*args, **kwargs): + nonlocal retry_count + retry_count += 1 + if retry_count < max_test_retries: + return MagicMock(status=500) # Simulate failure + else: + return MagicMock(status=200) # Simulate success to stop retrying + + self.session.request = mock_request + + policy = RetryPolicy(max_attempts=-1, retryable_http_status_codes=[500]) + response = await policy.make_http_call(self.session, self.req) + + # Assert that the retry logic was executed the expected number of times + self.assertEqual(response.status, 200) + 
self.assertEqual(retry_count, max_test_retries) diff --git a/tests/clients/test_retries_policy_async.py b/tests/clients/test_retries_policy_async.py new file mode 100644 index 00000000..ebe6865d --- /dev/null +++ b/tests/clients/test_retries_policy_async.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +from unittest import mock +from unittest.mock import MagicMock, patch, AsyncMock + +from grpc import StatusCode, RpcError + +from dapr.clients.retry import RetryPolicy + + +class RetryPolicyGrpcAsyncTests(unittest.IsolatedAsyncioTestCase): + async def test_run_rpc_async_with_retry_success(self): + mock_func = AsyncMock(return_value='success') + + policy = RetryPolicy(max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE]) + result, _ = await policy.run_rpc_async(mock_func, 'foo', arg1=1, arg2=2) + + self.assertEqual(result, 'success') + mock_func.assert_awaited_once_with('foo', arg1=1, arg2=2) + + async def test_run_rpc_async_with_retry_no_retry(self): + mock_error = RpcError() + mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE) + mock_func = AsyncMock(side_effect=mock_error) + + with self.assertRaises(RpcError): + policy = RetryPolicy(max_attempts=0) + await policy.run_rpc_async(mock_func) + mock_func.assert_awaited_once() + + @patch('asyncio.sleep', return_value=None) + async def test_run_rpc_async_with_retry_fail(self, mock_sleep): + mock_error = RpcError() + 
mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE) + mock_func = AsyncMock(side_effect=mock_error) + with self.assertRaises(RpcError): + policy = RetryPolicy(max_attempts=4, initial_backoff=2, backoff_multiplier=1.5) + await policy.run_rpc_async(mock_func) + + self.assertEqual(mock_func.await_count, 4) + expected_sleep_calls = [ + mock.call(2.0), # First sleep call + mock.call(3.0), # Second sleep call + mock.call(4.5), # Third sleep call + ] + mock_sleep.assert_has_calls(expected_sleep_calls, any_order=False) + + async def test_run_rpc_async_with_retry_fail_with_another_status_code(self): + mock_error = RpcError() + mock_error.code = MagicMock(return_value=StatusCode.FAILED_PRECONDITION) + mock_func = AsyncMock(side_effect=mock_error) + + with self.assertRaises(RpcError): + policy = RetryPolicy( + max_attempts=3, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE] + ) + await policy.run_rpc_async(mock_func) + + mock_func.assert_awaited_once() + + @patch('asyncio.sleep', return_value=None) + async def test_run_rpc_async_with_retry_fail_with_max_backoff(self, mock_sleep): + mock_error = RpcError() + mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE) + mock_func = AsyncMock(side_effect=mock_error) + + with self.assertRaises(RpcError): + policy = RetryPolicy( + max_attempts=4, initial_backoff=2, backoff_multiplier=1.5, max_backoff=3 + ) + await policy.run_rpc_async(mock_func) + + self.assertEqual(mock_func.await_count, 4) + expected_sleep_calls = [ + mock.call(2.0), # First sleep call + mock.call(3.0), # Second sleep call + mock.call(3.0), # Third sleep call + ] + mock_sleep.assert_has_calls(expected_sleep_calls, any_order=False) + + @patch('asyncio.sleep', return_value=None) + async def test_run_rpc_async_with_infinite_retries(self, mock_sleep): + # Testing a function that's supposed to run forever is tricky, so we'll simulate it + # Instead of a fixed side effect, we'll create a function that's supposed to + # break out of the cycle 
after X calls. + # Then we assert that the function was called X times before breaking the loop + + # Configure the policy to simulate infinite retries + policy = RetryPolicy(max_attempts=-1, retryable_grpc_status_codes=[StatusCode.UNAVAILABLE]) + + mock_error = RpcError() + mock_error.code = MagicMock(return_value=StatusCode.UNAVAILABLE) + mock_func = AsyncMock() + + # Use a side effect on the mock to count calls and eventually interrupt the loop + call_count = 0 + + def side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + + if call_count >= 10: # Five calls before breaking the loop + raise Exception('Test interrupt') + + raise mock_error + + mock_func.side_effect = side_effect + + # Run the test, expecting the custom exception to break the loop + with self.assertRaises(Exception) as context: + await policy.run_rpc_async(mock_func) + + self.assertEqual(str(context.exception), 'Test interrupt') + + # Verify the function was retried the expected number of times before interrupting + self.assertEqual(call_count, 10) diff --git a/tests/clients/test_timeout_interceptor.py b/tests/clients/test_timeout_interceptor.py new file mode 100644 index 00000000..79859b2e --- /dev/null +++ b/tests/clients/test_timeout_interceptor.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import unittest +from unittest.mock import Mock, patch +from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor +from dapr.conf import settings + + +class DaprClientTimeoutInterceptorTests(unittest.TestCase): + def test_intercept_unary_unary_with_timeout(self): + continuation = Mock() + request = Mock() + client_call_details = Mock() + client_call_details.method = 'method' + client_call_details.timeout = 10 + client_call_details.metadata = 'metadata' + client_call_details.credentials = 'credentials' + client_call_details.wait_for_ready = 'wait_for_ready' + client_call_details.compression = 'compression' + + DaprClientTimeoutInterceptor().intercept_unary_unary( + continuation, client_call_details, request + ) + continuation.assert_called_once_with(client_call_details, request) + + @patch.object(settings, 'DAPR_API_TIMEOUT_SECONDS', 7) + def test_intercept_unary_unary_without_timeout(self): + continuation = Mock() + request = Mock() + client_call_details = Mock() + client_call_details.method = 'method' + client_call_details.timeout = None + client_call_details.metadata = 'metadata' + client_call_details.credentials = 'credentials' + client_call_details.wait_for_ready = 'wait_for_ready' + client_call_details.compression = 'compression' + + DaprClientTimeoutInterceptor().intercept_unary_unary( + continuation, client_call_details, request + ) + called_client_call_details = continuation.call_args[0][0] + self.assertEqual(7, called_client_call_details.timeout) diff --git a/tests/clients/test_timeout_interceptor_async.py b/tests/clients/test_timeout_interceptor_async.py new file mode 100644 index 00000000..d057df9f --- /dev/null +++ b/tests/clients/test_timeout_interceptor_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from unittest.mock import Mock, patch +from dapr.aio.clients.grpc.interceptors import DaprClientTimeoutInterceptorAsync +from dapr.conf import settings + + +class DaprClientTimeoutInterceptorAsyncTests(unittest.TestCase): + def test_intercept_unary_unary_with_timeout(self): + continuation = Mock() + request = Mock() + client_call_details = Mock() + client_call_details.method = 'method' + client_call_details.timeout = 10 + client_call_details.metadata = 'metadata' + client_call_details.credentials = 'credentials' + client_call_details.wait_for_ready = 'wait_for_ready' + + DaprClientTimeoutInterceptorAsync().intercept_unary_unary( + continuation, client_call_details, request + ) + continuation.assert_called_once_with(client_call_details, request) + + @patch.object(settings, 'DAPR_API_TIMEOUT_SECONDS', 7) + def test_intercept_unary_unary_without_timeout(self): + continuation = Mock() + request = Mock() + client_call_details = Mock() + client_call_details.method = 'method' + client_call_details.timeout = None + client_call_details.metadata = 'metadata' + client_call_details.credentials = 'credentials' + client_call_details.wait_for_ready = 'wait_for_ready' + + DaprClientTimeoutInterceptorAsync().intercept_unary_unary( + continuation, client_call_details, request + ) + called_client_call_details = continuation.call_args[0][0] + self.assertEqual(7, called_client_call_details.timeout)