Skip to content

Commit

Permalink
Adds async
Browse files Browse the repository at this point in the history
Signed-off-by: Elena Kolevska <elena@kolevska.com>
  • Loading branch information
elena-kolevska committed Oct 9, 2024
1 parent 045ca7c commit 59d3c73
Show file tree
Hide file tree
Showing 10 changed files with 485 additions and 134 deletions.
90 changes: 68 additions & 22 deletions dapr/aio/clients/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from warnings import warn

from typing import Callable, Dict, Optional, Text, Union, Sequence, List, Any
from typing import Callable, Dict, Optional, Text, Union, Sequence, List, Any, Awaitable
from typing_extensions import Self

from google.protobuf.message import Message as GrpcMessage
Expand All @@ -39,12 +39,14 @@
AioRpcError,
)

from dapr.aio.clients.grpc.subscription import Subscription
from dapr.clients.exceptions import DaprInternalError, DaprGrpcError
from dapr.clients.grpc._crypto import EncryptOptions, DecryptOptions
from dapr.clients.grpc._state import StateOptions, StateItem
from dapr.clients.grpc._helpers import getWorkflowRuntimeStatus
from dapr.clients.health import DaprHealth
from dapr.clients.retry import RetryPolicy
from dapr.common.pubsub.subscription import StreamInactiveError
from dapr.conf.helpers import GrpcEndpoint
from dapr.conf import settings
from dapr.proto import api_v1, api_service_v1, common_v1
Expand Down Expand Up @@ -74,27 +76,14 @@
BindingRequest,
TransactionalStateOperation,
)
from dapr.clients.grpc._response import (
BindingResponse,
DaprResponse,
GetSecretResponse,
GetBulkSecretResponse,
GetMetadataResponse,
InvokeMethodResponse,
UnlockResponseStatus,
StateResponse,
BulkStatesResponse,
BulkStateItem,
ConfigurationResponse,
QueryResponse,
QueryResponseItem,
RegisteredComponents,
ConfigurationWatcher,
TryLockResponse,
UnlockResponse,
GetWorkflowResponse,
StartWorkflowResponse,
)
from dapr.clients.grpc._response import (BindingResponse, DaprResponse, GetSecretResponse,
GetBulkSecretResponse, GetMetadataResponse,
InvokeMethodResponse, UnlockResponseStatus, StateResponse,
BulkStatesResponse, BulkStateItem, ConfigurationResponse,
QueryResponse, QueryResponseItem, RegisteredComponents,
ConfigurationWatcher, TryLockResponse, UnlockResponse,
GetWorkflowResponse, StartWorkflowResponse,
TopicEventResponse, )


class DaprGrpcClientAsync:
Expand Down Expand Up @@ -482,6 +471,63 @@ async def publish_event(

return DaprResponse(await call.initial_metadata())

async def subscribe(self, pubsub_name: str, topic: str, metadata: Optional[dict] = None,
dead_letter_topic: Optional[str] = None, ) -> Subscription:
"""
Subscribe to a topic with a bidirectional stream
Args:
pubsub_name (str): The name of the pubsub component.
topic (str): The name of the topic.
metadata (Optional[dict]): Additional metadata for the subscription.
dead_letter_topic (Optional[str]): Name of the dead-letter topic.
Returns:
Subscription: The Subscription object managing the stream.
"""
subscription = Subscription(self._stub, pubsub_name, topic, metadata,
dead_letter_topic)
await subscription.start()
return subscription

async def subscribe_with_handler(self, pubsub_name: str, topic: str,
handler_fn: Callable[..., TopicEventResponse], metadata: Optional[dict] = None,
dead_letter_topic: Optional[str] = None, ) -> Callable[[], Awaitable[None]]:
"""
Subscribe to a topic with a bidirectional stream and a message handler function
Args:
pubsub_name (str): The name of the pubsub component.
topic (str): The name of the topic.
handler_fn (Callable[..., TopicEventResponse]): The function to call when a message is received.
metadata (Optional[dict]): Additional metadata for the subscription.
dead_letter_topic (Optional[str]): Name of the dead-letter topic.
Returns:
Callable[[], Awaitable[None]]: An async function to close the subscription.
"""
subscription = await self.subscribe(pubsub_name, topic, metadata, dead_letter_topic)

async def stream_messages(sub: Subscription):
while True:
try:
message = await sub.next_message()
if message:
response = await handler_fn(message)
if response:
await subscription._respond(message, response.status)
else:
continue
except StreamInactiveError:
break

async def close_subscription():
await subscription.close()

asyncio.create_task(stream_messages(subscription))

return close_subscription

async def get_state(
self,
store_name: str,
Expand Down
109 changes: 109 additions & 0 deletions dapr/aio/clients/grpc/subscription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import asyncio
from grpc import StatusCode
from grpc.aio import AioRpcError

from dapr.clients.grpc._response import TopicEventResponse
from dapr.clients.health import DaprHealth
from dapr.common.pubsub.subscription import StreamInactiveError, SubscriptionMessage
from dapr.proto import api_v1, appcallback_v1

class Subscription:

def __init__(self, stub, pubsub_name, topic, metadata=None, dead_letter_topic=None):
self._stub = stub
self._pubsub_name = pubsub_name
self._topic = topic
self._metadata = metadata or {}
self._dead_letter_topic = dead_letter_topic or ''
self._stream = None
self._send_queue = asyncio.Queue()
self._stream_active = asyncio.Event()

async def start(self):
async def outgoing_request_iterator():
try:
initial_request = api_v1.SubscribeTopicEventsRequestAlpha1(
initial_request=api_v1.SubscribeTopicEventsRequestInitialAlpha1(
pubsub_name=self._pubsub_name,
topic=self._topic,
metadata=self._metadata,
dead_letter_topic=self._dead_letter_topic,
)
)
yield initial_request

while self._stream_active.is_set():
try:
response = await asyncio.wait_for(self._send_queue.get(), timeout=1.0)
yield response
except asyncio.TimeoutError:
continue
except Exception as e:
raise Exception(f'Error while writing to stream: {e}')

self._stream = self._stub.SubscribeTopicEventsAlpha1(outgoing_request_iterator())
self._stream_active.set()
await self._stream.read() # discard the initial message

async def reconnect_stream(self):
await self.close()
DaprHealth.wait_until_ready()
print('Attempting to reconnect...')
await self.start()

async def next_message(self):
if not self._stream_active.is_set():
raise StreamInactiveError('Stream is not active')

try:
if self._stream is not None:
message = await self._stream.read()
if message is None:
return None
return SubscriptionMessage(message.event_message)
except AioRpcError as e:
if e.code() == StatusCode.UNAVAILABLE:
print(f'gRPC error while reading from stream: {e.details()}, '
f'Status Code: {e.code()}. '
f'Attempting to reconnect...')
await self.reconnect_stream()
elif e.code() != StatusCode.CANCELLED:
raise Exception(f'gRPC error while reading from subscription stream: {e.details()} '
f'Status Code: {e.code()}')
except Exception as e:
raise Exception(f'Error while fetching message: {e}')

return None

async def _respond(self, message, status):
try:
status = appcallback_v1.TopicEventResponse(status=status.value)
response = api_v1.SubscribeTopicEventsRequestProcessedAlpha1(
id=message.id(), status=status
)
msg = api_v1.SubscribeTopicEventsRequestAlpha1(event_processed=response)
if not self._stream_active.is_set():
raise StreamInactiveError('Stream is not active')
await self._send_queue.put(msg)
except Exception as e:
print(f"Can't send message on inactive stream: {e}")

async def respond_success(self, message):
await self._respond(message, TopicEventResponse('success').status)

async def respond_retry(self, message):
await self._respond(message, TopicEventResponse('retry').status)

async def respond_drop(self, message):
await self._respond(message, TopicEventResponse('drop').status)

async def close(self):
if self._stream:
try:
self._stream.cancel()
self._stream_active.clear()
except AioRpcError as e:
if e.code() != StatusCode.CANCELLED:
raise Exception(f'Error while closing stream: {e}')
except Exception as e:
raise Exception(f'Error while closing stream: {e}')
4 changes: 2 additions & 2 deletions dapr/clients/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def subscribe_with_handler(
Args:
pubsub_name (str): The name of the pubsub component.
topic (str): The name of the topic.
handler_fn (Callable[..., TopicEventResponseStatus]): The function to call when a message is received.
handler_fn (Callable[..., TopicEventResponse]): The function to call when a message is received.
metadata (Optional[MetadataTuple]): Additional metadata for the subscription.
dead_letter_topic (Optional[str]): Name of the dead-letter topic.
timeout (Optional[int]): The time in seconds to wait for a message before returning None
Expand All @@ -540,7 +540,7 @@ def stream_messages(sub):
# Process the message
response = handler_fn(message)
if response:
subscription._respond(message, response)
subscription.respond(message, response.status)
else:
# No message received
continue
Expand Down
102 changes: 6 additions & 96 deletions dapr/clients/grpc/subscription.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
import json

from google.protobuf.json_format import MessageToDict
from grpc import RpcError, StatusCode, Call # type: ignore

from dapr.clients.grpc._response import TopicEventResponse
from dapr.clients.health import DaprHealth
from dapr.common.pubsub.subscription import StreamInactiveError, SubscriptionMessage
from dapr.proto import api_v1, appcallback_v1
import queue
import threading
from typing import Optional, Union
from typing import Optional

from dapr.proto.runtime.v1.appcallback_pb2 import TopicEventRequest


class Subscription:
SUCCESS = TopicEventResponse('success').status
RETRY = TopicEventResponse('retry').status
DROP = TopicEventResponse('drop').status

def __init__(self, stub, pubsub_name, topic, metadata=None, dead_letter_topic=None):
self._stub = stub
Expand Down Expand Up @@ -102,7 +96,7 @@ def next_message(self):

return None

def _respond(self, message, status):
def respond(self, message, status):
try:
status = appcallback_v1.TopicEventResponse(status=status.value)
response = api_v1.SubscribeTopicEventsRequestProcessedAlpha1(
Expand All @@ -116,13 +110,13 @@ def _respond(self, message, status):
print(f"Can't send message on inactive stream: {e}")

def respond_success(self, message):
self._respond(message, self.SUCCESS)
self.respond(message, TopicEventResponse('success').status)

def respond_retry(self, message):
self._respond(message, self.RETRY)
self.respond(message, TopicEventResponse('retry').status)

def respond_drop(self, message):
self._respond(message, self.DROP)
self.respond(message, TopicEventResponse('drop').status)

def _set_stream_active(self):
with self._stream_lock:
Expand All @@ -146,87 +140,3 @@ def close(self):
raise Exception(f'Error while closing stream: {e}')
except Exception as e:
raise Exception(f'Error while closing stream: {e}')


class SubscriptionMessage:
def __init__(self, msg: TopicEventRequest):
self._id: str = msg.id
self._source: str = msg.source
self._type: str = msg.type
self._spec_version: str = msg.spec_version
self._data_content_type: str = msg.data_content_type
self._topic: str = msg.topic
self._pubsub_name: str = msg.pubsub_name
self._raw_data: bytes = msg.data
self._data: Optional[Union[dict, str]] = None

try:
self._extensions = MessageToDict(msg.extensions)
except Exception as e:
self._extensions = {}
print(f'Error parsing extensions: {e}')

# Parse the content based on its media type
if self._raw_data and len(self._raw_data) > 0:
self._parse_data_content()

def id(self):
return self._id

def source(self):
return self._source

def type(self):
return self._type

def spec_version(self):
return self._spec_version

def data_content_type(self):
return self._data_content_type

def topic(self):
return self._topic

def pubsub_name(self):
return self._pubsub_name

def raw_data(self):
return self._raw_data

def extensions(self):
return self._extensions

def data(self):
return self._data

def _parse_data_content(self):
try:
if self._data_content_type == 'application/json':
try:
self._data = json.loads(self._raw_data)
except json.JSONDecodeError:
print(f'Error parsing json message data from topic {self._topic}')
pass # If JSON parsing fails, keep `data` as None
elif self._data_content_type == 'text/plain':
# Assume UTF-8 encoding
try:
self._data = self._raw_data.decode('utf-8')
except UnicodeDecodeError:
print(f'Error decoding message data from topic {self._topic} as UTF-8')
elif self._data_content_type.startswith(
'application/'
) and self._data_content_type.endswith('+json'):
# Handle custom JSON-based media types (e.g., application/vnd.api+json)
try:
self._data = json.loads(self._raw_data)
except json.JSONDecodeError:
print(f'Error parsing json message data from topic {self._topic}')
pass # If JSON parsing fails, keep `data` as None
except Exception as e:
# Log or handle any unexpected exceptions
print(f'Error parsing media type: {e}')


class StreamInactiveError(Exception):
pass
Loading

0 comments on commit 59d3c73

Please sign in to comment.