From ce3b84c67a31785f45eb46f154ef08af6edc9a36 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Wed, 9 Oct 2024 14:43:49 -0400 Subject: [PATCH] fix: streaming for sync REST API calls (#2204) --- .../%sub/services/%service/_shared_macros.j2 | 12 ++++++++-- .../services/%service/transports/rest.py.j2 | 2 +- .../%service/transports/rest_asyncio.py.j2 | 2 +- .../cloud_redis/transports/rest_asyncio.py | 24 +++++++++---------- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/_shared_macros.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/_shared_macros.j2 index c35c897500..cc795cc91a 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/_shared_macros.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/_shared_macros.j2 @@ -152,7 +152,7 @@ def _get_http_options(): return http_options {% endmacro %} -{% macro response_method(body_spec, is_async=False) %} +{% macro response_method(body_spec, is_async=False, is_streaming_method=False) %} {% set async_prefix = "async " if is_async else "" %} {% set await_prefix = "await " if is_async else "" %} @staticmethod @@ -177,6 +177,14 @@ def _get_http_options(): {% if body_spec %} data=body, {% endif %} + {% if not is_async and is_streaming_method %} + {# NOTE: The underlying `requests` library used for making a sync request + # requires us to set `stream=True` to avoid loading the entire response + # into memory at once. For an async request, given its nature where it + # reads data chunk by chunk, this is not required. + #} + stream=True, + {% endif %} ) return response {% endmacro %} @@ -400,7 +408,7 @@ class _{{ name }}(_Base{{ service.name }}RestTransport._Base{{name}}, {{ async_m return hash("{{ async_method_name_prefix }}{{ service.name }}RestTransport.{{ name }}") {% set body_spec = api.mixin_http_options["{}".format(name)][0].body %} - {{ response_method(body_spec) | indent(4) }} + {{ response_method(body_spec, is_async=is_async, is_streaming_method=None) | indent(4) }} {{ async_prefix }}def __call__(self, request: {{ sig.request_type }}, *, diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index 796aa99325..6bdbbbcbc4 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -196,7 +196,7 @@ class {{service.name}}RestTransport(_Base{{ service.name }}RestTransport): {% if method.http_options and not method.client_streaming %} {% set body_spec = method.http_options[0].body %} - {{ shared_macros.response_method(body_spec)|indent(8) }} + {{ shared_macros.response_method(body_spec, is_async=False, is_streaming_method=method.server_streaming)|indent(8) }} {% endif %}{# method.http_options and not method.client_streaming #} def __call__(self, diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest_asyncio.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest_asyncio.py.j2 index be65f9db68..6ae4e67786 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest_asyncio.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest_asyncio.py.j2 @@ -155,7 +155,7 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport): {# TODO(https://github.com/googleapis/gapic-generator-python/issues/2169): Implement client streaming method. #} {% if method.http_options and not method.client_streaming %} {% set body_spec = method.http_options[0].body %} - {{ shared_macros.response_method(body_spec, is_async=True)|indent(8) }} + {{ shared_macros.response_method(body_spec, is_async=True, is_streaming_method=None)|indent(8) }} {% endif %}{# method.http_options and not method.client_streaming and not method.paged_result_field #} async def __call__(self, diff --git a/tests/integration/goldens/redis/google/cloud/redis_v1/services/cloud_redis/transports/rest_asyncio.py b/tests/integration/goldens/redis/google/cloud/redis_v1/services/cloud_redis/transports/rest_asyncio.py index 9b72bcc6b1..9c461e3c1e 100755 --- a/tests/integration/goldens/redis/google/cloud/redis_v1/services/cloud_redis/transports/rest_asyncio.py +++ b/tests/integration/goldens/redis/google/cloud/redis_v1/services/cloud_redis/transports/rest_asyncio.py @@ -1670,7 +1670,7 @@ def __hash__(self): return hash("AsyncCloudRedisRestTransport.GetLocation") @staticmethod - def _get_response( + async def _get_response( host, metadata, query_params, @@ -1683,7 +1683,7 @@ def _get_response( method = transcoded_request['method'] headers = dict(metadata) headers['Content-Type'] = 'application/json' - response = getattr(session, method)( + response = await getattr(session, method)( "{host}{uri}".format(host=host, uri=uri), timeout=timeout, headers=headers, @@ -1747,7 +1747,7 @@ def __hash__(self): return hash("AsyncCloudRedisRestTransport.ListLocations") @staticmethod - def _get_response( + async def _get_response( host, metadata, query_params, @@ -1760,7 +1760,7 @@ def _get_response( method = transcoded_request['method'] headers = dict(metadata) headers['Content-Type'] = 'application/json' - response = getattr(session, method)( + response = await getattr(session, method)( "{host}{uri}".format(host=host, uri=uri), timeout=timeout, headers=headers, @@ -1824,7 +1824,7 @@ def __hash__(self): return hash("AsyncCloudRedisRestTransport.CancelOperation") @staticmethod - def _get_response( + async def _get_response( host, metadata, query_params, @@ -1837,7 +1837,7 @@ def _get_response( method = transcoded_request['method'] headers = dict(metadata) headers['Content-Type'] = 'application/json' - response = getattr(session, method)( + response = await getattr(session, method)( "{host}{uri}".format(host=host, uri=uri), timeout=timeout, headers=headers, @@ -1894,7 +1894,7 @@ def __hash__(self): return hash("AsyncCloudRedisRestTransport.DeleteOperation") @staticmethod - def _get_response( + async def _get_response( host, metadata, query_params, @@ -1907,7 +1907,7 @@ def _get_response( method = transcoded_request['method'] headers = dict(metadata) headers['Content-Type'] = 'application/json' - response = getattr(session, method)( + response = await getattr(session, method)( "{host}{uri}".format(host=host, uri=uri), timeout=timeout, headers=headers, @@ -1964,7 +1964,7 @@ def __hash__(self): return hash("AsyncCloudRedisRestTransport.GetOperation") @staticmethod - def _get_response( + async def _get_response( host, metadata, query_params, @@ -1977,7 +1977,7 @@ def _get_response( method = transcoded_request['method'] headers = dict(metadata) headers['Content-Type'] = 'application/json' - response = getattr(session, method)( + response = await getattr(session, method)( "{host}{uri}".format(host=host, uri=uri), timeout=timeout, headers=headers, @@ -2041,7 +2041,7 @@ def __hash__(self): return hash("AsyncCloudRedisRestTransport.ListOperations") @staticmethod - def _get_response( + async def _get_response( host, metadata, query_params, @@ -2054,7 +2054,7 @@ def _get_response( method = transcoded_request['method'] headers = dict(metadata) headers['Content-Type'] = 'application/json' - response = getattr(session, method)( + response = await getattr(session, method)( "{host}{uri}".format(host=host, uri=uri), timeout=timeout, headers=headers,