diff --git a/README.md b/README.md index 50ce009..8c62d13 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,30 @@ A response may be sent via: properties=pika.BasicProperties(expiration='1000') ) ``` -Where `` is the queue to which the response will be published, and `data` is a `bytes` response (generally a `base64`-encoded `dict`). +Where `` is the queue to which the response will be published, and `data` +is a `bytes` response (generally a `base64`-encoded `dict`). + +#### Multi-part Responses +A callback function may choose to publish multiple response messages so the client +may receive partial responses as they are being generated. If multiple responses +will be returned, the following requirements must be met: +- Each response must be a dict with `_part` and `_is_final` keys. +- `_part` is defined as a non-negative integer (the first response will specify `0`). +- The final response must specify `_is_final=True` +- The final response *MUST NOT* require the client to handle partial responses + +## Client Requests +Most client applications will interact with services via `send_mq_request`. This +function will return a `dict` response to the input message. + +### Multi-part Responses +A caller may optionally include a `stream_callback` argument which may receive +partial responses if supported by the service generating the response. The +`stream_callback` will always be called with the final result that is returned +by `send_mq_request`. Keep in mind that the `timeout` param passed to +`send_mq_request` applies to the full response, so the `timeout` value should +reflect the longest time it will take for a final response to be generated, plus +some margin. ### [BETA] Asynchronous Consumers diff --git a/neon_mq_connector/utils/client_utils.py b/neon_mq_connector/utils/client_utils.py index 1cfe64e..13f935a 100644 --- a/neon_mq_connector/utils/client_utils.py +++ b/neon_mq_connector/utils/client_utils.py @@ -29,6 +29,8 @@ import uuid from threading import Event +from typing import Callable, Optional + from pika.channel import Channel from pika.exceptions import ProbableAccessDeniedError, StreamLostError from neon_mq_connector.connector import MQConnector @@ -60,7 +62,8 @@ def __init__(self, config: dict, service_name: str, vhost: str): def send_mq_request(vhost: str, request_data: dict, target_queue: str, response_queue: str = None, timeout: int = 30, - expect_response: bool = True) -> dict: + expect_response: bool = True, + stream_callback: Optional[Callable[[dict], None]] = None) -> dict: """ Sends a request to the MQ server and returns the response. :param vhost: vhost to target @@ -68,8 +71,11 @@ def send_mq_request(vhost: str, request_data: dict, target_queue: str, :param target_queue: queue to post request to :param response_queue: optional queue to monitor for a response. Generally should be blank - :param timeout: time in seconds to wait for a response before timing out + :param timeout: time in seconds to wait for a complete response before + timing out. Note that in the event of a timeout, a partial response may + have been handled by `stream_callback`. :param expect_response: boolean indicating whether a response is expected + :param stream_callback: Optional function to pass partial responses to :return: response to request """ response_queue = response_queue or uuid.uuid4().hex @@ -94,22 +100,31 @@ def handle_mq_response(channel: Channel, method, _, body): """ api_output = b64_to_dict(body) - # The Messagebus connector generates a unique `message_id` for each - # response message. Check context for the original one; otherwise, - # check in output directly as some APIs emit responses without a unique - # message_id + # Backwards-compat. handles `context` in response for raw `Message` + # objects sent across the MQ bus api_output_msg_id = \ api_output.get('context', api_output).get('mq', api_output).get('message_id') - # TODO: One of these specs should be deprecated if api_output_msg_id != api_output.get('message_id'): - LOG.debug(f"Handling message_id from response context") + # TODO: `context.mq` handling should be deprecated + LOG.warning(f"Handling message_id from response context") if api_output_msg_id == message_id: LOG.debug(f'MQ output: {api_output}') channel.basic_ack(delivery_tag=method.delivery_tag) - channel.close() - response_data.update(api_output) - response_event.set() + if isinstance(api_output.get('_part'), int): + # TODO: Consider forcing these to be passed to `stream_callback` synchronously + # Handle multi-part responses + if stream_callback: + # Pass each part to the stream callback method if defined + stream_callback(api_output) + if api_output.get('_is_final'): + # Always return final result + response_data.update(api_output) + else: + response_data.update(api_output) + if api_output.get('_is_final', True): + channel.close() + response_event.set() else: channel.basic_nack(delivery_tag=method.delivery_tag) LOG.debug(f"Ignoring {api_output_msg_id} waiting for {message_id}") diff --git a/tests/test_utils.py b/tests/test_utils.py index cda9e04..b0749ca 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -109,6 +109,28 @@ def respond(channel, method, _, body): properties=pika.BasicProperties(expiration='1000')) channel.basic_ack(delivery_tag=method.delivery_tag) + @staticmethod + def respond_multiple(channel, method, _, body): + request = b64_to_dict(body) + num_parts = request.get("num_parts", 3) + base_response = {"message_id": request["message_id"], + "success": True, + "request_data": request["data"]} + reply_channel = request.get("routing_key") + channel.queue_declare(queue=reply_channel) + response_text = "" + for i in range(num_parts): + response_text += f" {i}" + response = {**base_response, **{"response": response_text, + "_part": i, + "_is_final": i == num_parts - 1}} + channel.basic_publish(exchange='', + routing_key=reply_channel, + body=dict_to_b64(response), + properties=pika.BasicProperties(expiration='1000')) + time.sleep(0.5) # Used to ensure synchronous response handling + channel.basic_ack(delivery_tag=method.delivery_tag) + @pytest.mark.usefixtures("rmq_instance") class TestClientUtils(unittest.TestCase): @@ -132,6 +154,11 @@ def setUp(self) -> None: INPUT_CHANNEL, self.test_connector.respond, auto_ack=False) + self.test_connector.register_consumer("neon_utils_test_multi", + vhost, + f"{INPUT_CHANNEL}-multi", + self.test_connector.respond_multiple, + auto_ack=False) self.test_connector.run_consumers() @classmethod @@ -156,6 +183,27 @@ def test_send_mq_request_spec_output_channel_valid(self): self.assertTrue(response["success"]) self.assertEqual(response["request_data"], request["data"]) + def test_multi_part_mq_response(self): + from neon_mq_connector.utils.client_utils import send_mq_request + request = {"data": time.time(), + "num_parts": 5} + target_queue = f"{INPUT_CHANNEL}-multi" + stream_callback = Mock() + response = send_mq_request("/neon_testing", request, target_queue, + stream_callback=stream_callback) + + self.assertEqual(stream_callback.call_count, request['num_parts'], + stream_callback.call_args_list) + + self.assertIsInstance(response, dict, response) + self.assertTrue(response.get("success"), response) + self.assertEqual(response["request_data"], request["data"]) + self.assertEqual(len(response['response'].split()), request['num_parts']) + self.assertTrue(response['_is_final']) + + # Last callback is the same as the standard response + self.assertEqual(response, stream_callback.call_args[0][0]) + def test_multiple_mq_requests(self): from neon_mq_connector.utils.client_utils import send_mq_request responses = dict()