From 0c2245cb3c5c0b599c58e3c949075f1094da11e9 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Wed, 4 Dec 2024 18:43:57 -0800 Subject: [PATCH 1/5] WIP support for intermediate responses for LLM streaming --- neon_mq_connector/utils/client_utils.py | 32 +++++++++++++++++-------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/neon_mq_connector/utils/client_utils.py b/neon_mq_connector/utils/client_utils.py index 1cfe64e..3c82fe0 100644 --- a/neon_mq_connector/utils/client_utils.py +++ b/neon_mq_connector/utils/client_utils.py @@ -29,6 +29,8 @@ import uuid from threading import Event +from typing import Callable, Optional + from pika.channel import Channel from pika.exceptions import ProbableAccessDeniedError, StreamLostError from neon_mq_connector.connector import MQConnector @@ -60,7 +62,8 @@ def __init__(self, config: dict, service_name: str, vhost: str): def send_mq_request(vhost: str, request_data: dict, target_queue: str, response_queue: str = None, timeout: int = 30, - expect_response: bool = True) -> dict: + expect_response: bool = True, + stream_callback: Optional[Callable[[dict], None]] = None) -> dict: """ Sends a request to the MQ server and returns the response. :param vhost: vhost to target @@ -70,6 +73,7 @@ def send_mq_request(vhost: str, request_data: dict, target_queue: str, Generally should be blank :param timeout: time in seconds to wait for a response before timing out :param expect_response: boolean indicating whether a response is expected + :param stream_callback: Optional function to pass partial responses to :return: response to request """ response_queue = response_queue or uuid.uuid4().hex @@ -94,22 +98,30 @@ def handle_mq_response(channel: Channel, method, _, body): """ api_output = b64_to_dict(body) - # The Messagebus connector generates a unique `message_id` for each - # response message. Check context for the original one; otherwise, - # check in output directly as some APIs emit responses without a unique - # message_id + # Backwards-compat. handles `context` in response for raw `Message` + # objects sent across the MQ bus api_output_msg_id = \ api_output.get('context', api_output).get('mq', api_output).get('message_id') - # TODO: One of these specs should be deprecated if api_output_msg_id != api_output.get('message_id'): - LOG.debug(f"Handling message_id from response context") + # TODO: `context.mq` handling should be deprecated + LOG.warning(f"Handling message_id from response context") if api_output_msg_id == message_id: LOG.debug(f'MQ output: {api_output}') channel.basic_ack(delivery_tag=method.delivery_tag) - channel.close() - response_data.update(api_output) - response_event.set() + if api_output.get('_part'): + # Handle multi-part responses + if stream_callback: + # Pass each part to the stream callback method if defined + stream_callback(api_output) + if api_output.get('_is_final'): + # Always return final result + response_data.update(api_output) + else: + response_data.update(api_output) + if api_output.get('_is_final', True): + channel.close() + response_event.set() else: channel.basic_nack(delivery_tag=method.delivery_tag) LOG.debug(f"Ignoring {api_output_msg_id} waiting for {message_id}") From 34e2a84b09c8293a0585084ff954215e6f7fc96c Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Tue, 10 Dec 2024 12:55:38 -0800 Subject: [PATCH 2/5] Fix handling of multi-part `0`-indexed returns Add test coverage for `send_mq_request` with `stream_callback` --- neon_mq_connector/utils/client_utils.py | 2 +- tests/test_utils.py | 44 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/neon_mq_connector/utils/client_utils.py b/neon_mq_connector/utils/client_utils.py index 3c82fe0..9cf368c 100644 --- a/neon_mq_connector/utils/client_utils.py +++ b/neon_mq_connector/utils/client_utils.py @@ -109,7 +109,7 @@ def handle_mq_response(channel: Channel, method, _, body): if api_output_msg_id == message_id: LOG.debug(f'MQ output: {api_output}') channel.basic_ack(delivery_tag=method.delivery_tag) - if api_output.get('_part'): + if isinstance(api_output.get('_part'), int): # Handle multi-part responses if stream_callback: # Pass each part to the stream callback method if defined diff --git a/tests/test_utils.py b/tests/test_utils.py index cda9e04..c234b46 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -109,6 +109,27 @@ def respond(channel, method, _, body): properties=pika.BasicProperties(expiration='1000')) channel.basic_ack(delivery_tag=method.delivery_tag) + @staticmethod + def respond_multiple(channel, method, _, body): + request = b64_to_dict(body) + num_parts = request.get("num_parts", 3) + base_response = {"message_id": request["message_id"], + "success": True, + "request_data": request["data"]} + reply_channel = request.get("routing_key") + channel.queue_declare(queue=reply_channel) + response_text = "" + for i in range(num_parts): + response_text += f" {i}" + response = {**base_response, **{"response": response_text, + "_part": i, + "_is_final": i == num_parts - 1}} + channel.basic_publish(exchange='', + routing_key=reply_channel, + body=dict_to_b64(response), + properties=pika.BasicProperties(expiration='1000')) + channel.basic_ack(delivery_tag=method.delivery_tag) + @pytest.mark.usefixtures("rmq_instance") class TestClientUtils(unittest.TestCase): @@ -132,6 +153,11 @@ def setUp(self) -> None: INPUT_CHANNEL, self.test_connector.respond, auto_ack=False) + self.test_connector.register_consumer("neon_utils_test_multi", + vhost, + f"{INPUT_CHANNEL}-multi", + self.test_connector.respond_multiple, + auto_ack=False) self.test_connector.run_consumers() @classmethod @@ -156,6 +182,24 @@ def test_send_mq_request_spec_output_channel_valid(self): self.assertTrue(response["success"]) self.assertEqual(response["request_data"], request["data"]) + def test_multi_part_mq_response(self): + from neon_mq_connector.utils.client_utils import send_mq_request + request = {"data": time.time(), + "num_parts": 5} + target_queue = f"{INPUT_CHANNEL}-multi" + stream_callback = Mock() + response = send_mq_request("/neon_testing", request, target_queue, + stream_callback=stream_callback) + + self.assertEqual(stream_callback.call_count, request['num_parts'], + stream_callback.call_args_list) + + self.assertIsInstance(response, dict, response) + self.assertTrue(response.get("success"), response) + self.assertEqual(response["request_data"], request["data"]) + self.assertEqual(len(response['response'].split()), request['num_parts']) + self.assertTrue(response['_is_final']) + def test_multiple_mq_requests(self): from neon_mq_connector.utils.client_utils import send_mq_request responses = dict() From 162510ae6fda3903609d22e10705c9c829f81ec8 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Tue, 10 Dec 2024 13:25:33 -0800 Subject: [PATCH 3/5] Update documentation to describe multi-part response behavior Add test for `stream_callback` final response --- README.md | 23 ++++++++++++++++++++++- neon_mq_connector/utils/client_utils.py | 4 +++- tests/test_utils.py | 3 +++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 50ce009..484a5af 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,28 @@ A response may be sent via: properties=pika.BasicProperties(expiration='1000') ) ``` -Where `` is the queue to which the response will be published, and `data` is a `bytes` response (generally a `base64`-encoded `dict`). +Where `` is the queue to which the response will be published, and `data` +is a `bytes` response (generally a `base64`-encoded `dict`). + +#### Multi-part Responses +A callback function may choose to publish multiple response messages so the client +may receive partial responses as they are being generated. If multiple responses +will be returned, the following requirements must be met: +- Each response must be a dict with `_part` and `_is_final` keys. +- The final response must specify `_is_final=True` +- The final response *MUST NOT* require the client to handle partial responses + +## Client Requests +Most client applications will interact with services via `send_mq_request`. This +function will return a `dict` response to the input message. + +### Multi-part Responses +A caller may optionally include a `stream_callback` argument which may receive +partial responses if supported by the service generating the response. The +`stream_callback` will always be called with the final result that is returned +by `send_mq_request`. Keep in mind that the `timeout` param passed to +`send_mq_request` applies to the full response, so it may be desirable to increase +the timeout if ### [BETA] Asynchronous Consumers diff --git a/neon_mq_connector/utils/client_utils.py b/neon_mq_connector/utils/client_utils.py index 9cf368c..0a33369 100644 --- a/neon_mq_connector/utils/client_utils.py +++ b/neon_mq_connector/utils/client_utils.py @@ -71,7 +71,9 @@ def send_mq_request(vhost: str, request_data: dict, target_queue: str, :param target_queue: queue to post request to :param response_queue: optional queue to monitor for a response. Generally should be blank - :param timeout: time in seconds to wait for a response before timing out + :param timeout: time in seconds to wait for a complete response before + timing out. Note that in the event of a timeout, a partial response may + have been handled by `stream_callback`. :param expect_response: boolean indicating whether a response is expected :param stream_callback: Optional function to pass partial responses to :return: response to request diff --git a/tests/test_utils.py b/tests/test_utils.py index c234b46..b25d481 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -200,6 +200,9 @@ def test_multi_part_mq_response(self): self.assertEqual(len(response['response'].split()), request['num_parts']) self.assertTrue(response['_is_final']) + # Last callback is the same as the standard response + self.assertEqual(response, stream_callback.call_args[0][0]) + def test_multiple_mq_requests(self): from neon_mq_connector.utils.client_utils import send_mq_request responses = dict() From 7b3d4d8f6a72393aa558a4727dd4d6c61b110fd9 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Wed, 11 Dec 2024 13:51:09 -0800 Subject: [PATCH 4/5] Add time between multi-response emits to help responses be handled synchronously Add note to address synchronization of responses --- neon_mq_connector/utils/client_utils.py | 1 + tests/test_utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/neon_mq_connector/utils/client_utils.py b/neon_mq_connector/utils/client_utils.py index 0a33369..13f935a 100644 --- a/neon_mq_connector/utils/client_utils.py +++ b/neon_mq_connector/utils/client_utils.py @@ -112,6 +112,7 @@ def handle_mq_response(channel: Channel, method, _, body): LOG.debug(f'MQ output: {api_output}') channel.basic_ack(delivery_tag=method.delivery_tag) if isinstance(api_output.get('_part'), int): + # TODO: Consider forcing these to be passed to `stream_callback` synchronously # Handle multi-part responses if stream_callback: # Pass each part to the stream callback method if defined diff --git a/tests/test_utils.py b/tests/test_utils.py index b25d481..b0749ca 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -128,6 +128,7 @@ def respond_multiple(channel, method, _, body): routing_key=reply_channel, body=dict_to_b64(response), properties=pika.BasicProperties(expiration='1000')) + time.sleep(0.5) # Used to ensure synchronous response handling channel.basic_ack(delivery_tag=method.delivery_tag) From d1a8ec9f878be57269e9b594daf03930dff2fdec Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Thu, 26 Dec 2024 12:06:31 -0800 Subject: [PATCH 5/5] Update readme per review --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 484a5af..8c62d13 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ A callback function may choose to publish multiple response messages so the clie may receive partial responses as they are being generated. If multiple responses will be returned, the following requirements must be met: - Each response must be a dict with `_part` and `_is_final` keys. +- `_part` is defined as a non-negative integer (the first response will specify `0`). - The final response must specify `_is_final=True` - The final response *MUST NOT* require the client to handle partial responses @@ -85,8 +86,9 @@ A caller may optionally include a `stream_callback` argument which may receive partial responses if supported by the service generating the response. The `stream_callback` will always be called with the final result that is returned by `send_mq_request`. Keep in mind that the `timeout` param passed to -`send_mq_request` applies to the full response, so it may be desirable to increase -the timeout if +`send_mq_request` applies to the full response, so the `timeout` value should +reflect the longest time it will take for a final response to be generated, plus +some margin. ### [BETA] Asynchronous Consumers