Skip to content

Commit

Permalink
feat: add client debug logging support for async gRPC (#2291)
Browse files Browse the repository at this point in the history
  • Loading branch information
ohmayr authored Dec 11, 2024
1 parent dddf797 commit f45935a
Show file tree
Hide file tree
Showing 17 changed files with 867 additions and 203 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "delete_operation" not in self._stubs:
self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
self._stubs["delete_operation"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/DeleteOperation",
request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
response_deserializer=None,
Expand All @@ -52,7 +52,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "cancel_operation" not in self._stubs:
self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
self._stubs["cancel_operation"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/CancelOperation",
request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
response_deserializer=None,
Expand All @@ -72,7 +72,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "wait_operation" not in self._stubs:
self._stubs["wait_operation"] = self.grpc_channel.unary_unary(
self._stubs["wait_operation"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/WaitOperation",
request_serializer=operations_pb2.WaitOperationRequest.SerializeToString,
response_deserializer=None,
Expand All @@ -92,7 +92,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "get_operation" not in self._stubs:
self._stubs["get_operation"] = self.grpc_channel.unary_unary(
self._stubs["get_operation"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/GetOperation",
request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
response_deserializer=operations_pb2.Operation.FromString,
Expand All @@ -112,7 +112,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "list_operations" not in self._stubs:
self._stubs["list_operations"] = self.grpc_channel.unary_unary(
self._stubs["list_operations"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/ListOperations",
request_serializer=operations_pb2.ListOperationsRequest.SerializeToString,
response_deserializer=operations_pb2.ListOperationsResponse.FromString,
Expand All @@ -136,7 +136,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "list_locations" not in self._stubs:
self._stubs["list_locations"] = self.grpc_channel.unary_unary(
self._stubs["list_locations"] = self._logged_channel.unary_unary(
"/google.cloud.location.Locations/ListLocations",
request_serializer=locations_pb2.ListLocationsRequest.SerializeToString,
response_deserializer=locations_pb2.ListLocationsResponse.FromString,
Expand All @@ -156,7 +156,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "get_location" not in self._stubs:
self._stubs["get_location"] = self.grpc_channel.unary_unary(
self._stubs["get_location"] = self._logged_channel.unary_unary(
"/google.cloud.location.Locations/GetLocation",
request_serializer=locations_pb2.GetLocationRequest.SerializeToString,
response_deserializer=locations_pb2.Location.FromString,
Expand Down Expand Up @@ -188,7 +188,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "set_iam_policy" not in self._stubs:
self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary(
self._stubs["set_iam_policy"] = self._logged_channel.unary_unary(
"/google.iam.v1.IAMPolicy/SetIamPolicy",
request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString,
response_deserializer=policy_pb2.Policy.FromString,
Expand Down Expand Up @@ -216,7 +216,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "get_iam_policy" not in self._stubs:
self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary(
self._stubs["get_iam_policy"] = self._logged_channel.unary_unary(
"/google.iam.v1.IAMPolicy/GetIamPolicy",
request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString,
response_deserializer=policy_pb2.Policy.FromString,
Expand Down Expand Up @@ -246,7 +246,7 @@
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "test_iam_permissions" not in self._stubs:
self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary(
self._stubs["test_iam_permissions"] = self._logged_channel.unary_unary(
"/google.iam.v1.IAMPolicy/TestIamPermissions",
request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString,
response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

{% block content %}

import json
import logging as std_logging
import pickle
import warnings
Expand Down Expand Up @@ -69,7 +70,12 @@ class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO
elif isinstance(request, google.protobuf.message.Message):
request_payload = MessageToJson(request)
else:
request_payload = f"{type(result).__name__}: {pickle.dumps(request)}"
request_payload = f"{type(request).__name__}: {pickle.dumps(request)}"

request_metadata = {
key: value.decode("utf-8") if isinstance(value, bytes) else value
for key, value in request_metadata
}
grpc_request = {
"payload": request_payload,
"requestMethod": "grpc",
Expand All @@ -90,7 +96,7 @@ class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO
if logging_enabled: # pragma: NO COVER
response_metadata = response.trailing_metadata()
# Convert gRPC metadata `<class 'grpc.aio._metadata.Metadata'>` to list of tuples
metadata = dict([(k, v) for k, v in response_metadata]) if response_metadata else None
metadata = dict([(k, str(v)) for k, v in response_metadata]) if response_metadata else None
result = response.result()
if isinstance(result, proto.Message):
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2293): Investigate if we can improve this logic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
{% import "%namespace/%name_%version/%sub/services/%service/_shared_macros.j2" as shared_macros %}

import inspect
import json
import pickle
import logging as std_logging
import warnings
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union

Expand All @@ -16,8 +19,11 @@ from google.api_core import operations_v1
{% endif %}
from google.auth import credentials as ga_credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.protobuf.json_format import MessageToJson
import google.protobuf.message

import grpc # type: ignore
import proto # type: ignore
from grpc.experimental import aio # type: ignore

{% filter sort_lines %}
Expand Down Expand Up @@ -47,6 +53,81 @@ from google.longrunning import operations_pb2 # type: ignore
from .base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO
from .grpc import {{ service.name }}GrpcTransport

try:
from google.api_core import client_logging # type: ignore
CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER
except ImportError: # pragma: NO COVER
CLIENT_LOGGING_SUPPORTED = False

_LOGGER = std_logging.getLogger(__name__)


class _LoggingClientAIOInterceptor(grpc.aio.UnaryUnaryClientInterceptor): # pragma: NO COVER
async def intercept_unary_unary(self, continuation, client_call_details, request):
logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(std_logging.DEBUG)
if logging_enabled: # pragma: NO COVER
request_metadata = client_call_details.metadata
if isinstance(request, proto.Message):
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2293): Investigate if we can improve this logic
or wait for next gen protobuf.
#}
request_payload = type(request).to_json(request)
elif isinstance(request, google.protobuf.message.Message):
request_payload = MessageToJson(request)
else:
request_payload = f"{type(request).__name__}: {pickle.dumps(request)}"

request_metadata = {
key: value.decode("utf-8") if isinstance(value, bytes) else value
for key, value in request_metadata
}
grpc_request = {
"payload": request_payload,
"requestMethod": "grpc",
"metadata": dict(request_metadata),
}
_LOGGER.debug(
f"Sending request for {client_call_details.method}",
extra = {
"serviceName": "{{ service.meta.address.proto }}",
"rpcName": str(client_call_details.method),
"request": grpc_request,
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2275): logging `metadata` seems repetitive and may need to be cleaned up. We're including it within "request" for consistency with REST transport.' #}
"metadata": grpc_request["metadata"],
},
)
response = await continuation(client_call_details, request)
if logging_enabled: # pragma: NO COVER
response_metadata = await response.trailing_metadata()
# Convert gRPC metadata `<class 'grpc.aio._metadata.Metadata'>` to list of tuples
metadata = dict([(k, str(v)) for k, v in response_metadata]) if response_metadata else None
result = await response
if isinstance(result, proto.Message):
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2293): Investigate if we can improve this logic
or wait for next gen protobuf.
#}
response_payload = type(result).to_json(result)
elif isinstance(result, google.protobuf.message.Message):
response_payload = MessageToJson(result)
else:
response_payload = f"{type(result).__name__}: {pickle.dumps(result)}"
grpc_response = {
"payload": response_payload,
"metadata": metadata,
"status": "OK",
}
_LOGGER.debug(
f"Received response to rpc {client_call_details.method}.",
extra = {
"serviceName": "{{ service.meta.address.proto }}",
"rpcName": str(client_call_details.method),
"response": grpc_response,
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2275): logging `metadata` seems repetitive and may need to be cleaned up. We're including it within "request" for consistency with REST transport.' #}
"metadata": grpc_response["metadata"],
},
)
return response


class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
"""gRPC AsyncIO backend transport for {{ service.name }}.
Expand Down Expand Up @@ -242,8 +323,11 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
],
)

# Wrap messages. This must be done after self._grpc_channel exists
self._interceptor = _LoggingClientAIOInterceptor()
self._grpc_channel._unary_unary_interceptors.append(self._interceptor)
self._logged_channel = self._grpc_channel
self._wrap_with_kind = "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters
# Wrap messages. This must be done after self._logged_channel exists
self._prep_wrapped_messages(client_info)

@property
Expand All @@ -267,7 +351,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
# Quick check: Only create a new client if we do not already have one.
if self._operations_client is None:
self._operations_client = operations_v1.OperationsAsyncClient(
self.grpc_channel
self._logged_channel
)

# Return the client from cache.
Expand Down Expand Up @@ -297,7 +381,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if '{{ method.transport_safe_name|snake_case }}' not in self._stubs:
self._stubs['{{ method.transport_safe_name|snake_case }}'] = self.grpc_channel.{{ method.grpc_stub_type }}(
self._stubs['{{ method.transport_safe_name|snake_case }}'] = self._logged_channel.{{ method.grpc_stub_type }}(
'/{{ '.'.join(method.meta.address.package) }}.{{ service.name }}/{{ method.name }}',
request_serializer={{ method.input.ident }}.{% if method.input.ident.python_import.module.endswith('_pb2') %}SerializeToString{% else %}serialize{% endif %},
response_deserializer={{ method.output.ident }}.{% if method.output.ident.python_import.module.endswith('_pb2') %}FromString{% else %}deserialize{% endif %},
Expand Down Expand Up @@ -325,7 +409,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "set_iam_policy" not in self._stubs:
self._stubs["set_iam_policy"] = self.grpc_channel.unary_unary(
self._stubs["set_iam_policy"] = self._logged_channel.unary_unary(
"/google.iam.v1.IAMPolicy/SetIamPolicy",
request_serializer=iam_policy_pb2.SetIamPolicyRequest.SerializeToString,
response_deserializer=policy_pb2.Policy.FromString,
Expand All @@ -351,7 +435,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "get_iam_policy" not in self._stubs:
self._stubs["get_iam_policy"] = self.grpc_channel.unary_unary(
self._stubs["get_iam_policy"] = self._logged_channel.unary_unary(
"/google.iam.v1.IAMPolicy/GetIamPolicy",
request_serializer=iam_policy_pb2.GetIamPolicyRequest.SerializeToString,
response_deserializer=policy_pb2.Policy.FromString,
Expand Down Expand Up @@ -380,7 +464,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "test_iam_permissions" not in self._stubs:
self._stubs["test_iam_permissions"] = self.grpc_channel.unary_unary(
self._stubs["test_iam_permissions"] = self._logged_channel.unary_unary(
"/google.iam.v1.IAMPolicy/TestIamPermissions",
request_serializer=iam_policy_pb2.TestIamPermissionsRequest.SerializeToString,
response_deserializer=iam_policy_pb2.TestIamPermissionsResponse.FromString,
Expand All @@ -393,7 +477,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
{{ shared_macros.wrap_async_method_macro()|indent(4) }}

def close(self):
return self.grpc_channel.close()
return self._logged_channel.close()

@property
def kind(self) -> str:
Expand All @@ -405,4 +489,4 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
__all__ = (
'{{ service.name }}GrpcAsyncIOTransport',
)
{% endblock %}
{% endblock %}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging as std_logging
import pickle
import warnings
Expand Down Expand Up @@ -54,7 +55,12 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
elif isinstance(request, google.protobuf.message.Message):
request_payload = MessageToJson(request)
else:
request_payload = f"{type(result).__name__}: {pickle.dumps(request)}"
request_payload = f"{type(request).__name__}: {pickle.dumps(request)}"

request_metadata = {
key: value.decode("utf-8") if isinstance(value, bytes) else value
for key, value in request_metadata
}
grpc_request = {
"payload": request_payload,
"requestMethod": "grpc",
Expand All @@ -74,7 +80,7 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
if logging_enabled: # pragma: NO COVER
response_metadata = response.trailing_metadata()
# Convert gRPC metadata `<class 'grpc.aio._metadata.Metadata'>` to list of tuples
metadata = dict([(k, v) for k, v in response_metadata]) if response_metadata else None
metadata = dict([(k, str(v)) for k, v in response_metadata]) if response_metadata else None
result = response.result()
if isinstance(result, proto.Message):
response_payload = type(result).to_json(result)
Expand Down Expand Up @@ -1019,7 +1025,7 @@ def get_operation(
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "get_operation" not in self._stubs:
self._stubs["get_operation"] = self.grpc_channel.unary_unary(
self._stubs["get_operation"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/GetOperation",
request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
response_deserializer=operations_pb2.Operation.FromString,
Expand Down
Loading

0 comments on commit f45935a

Please sign in to comment.