Skip to content

Commit

Permalink
Flattening and LRO module conflicts (#332)
Browse files Browse the repository at this point in the history
Flattened method fields generate names in client methods. Barring a central authority, any source of names generates the possibility of name collisions and requires disambiguation.
Before this fix, flattened fields could generate name collisions with imported modules. This commit adds field names to the context with which names are disambiguated and subjects LRO operation info structures to the same collision avoidance logic as other rich data types.

Unblocks autogenerated unit tests for Datalabeling API.
  • Loading branch information
software-dov authored Mar 4, 2020
1 parent 5dc9c00 commit adcec6a
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 45 deletions.
55 changes: 29 additions & 26 deletions gapic/schema/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,21 @@ def names(self) -> FrozenSet[str]:
# Add names of all enums, messages, and fields.
answer: Set[str] = {e.name for e in self.all_enums.values()}
for message in self.all_messages.values():
answer = answer.union({f.name for f in message.fields.values()})
answer.update(f.name for f in message.fields.values())
answer.add(message.name)

# Identify any import module names where the same module name is used
# from distinct packages.
modules: Dict[str, Set[str]] = {}
for t in chain(*[m.field_types for m in self.all_messages.values()]):
modules.setdefault(t.ident.module, set())
modules[t.ident.module].add(t.ident.package)
for module_name, packages in modules.items():
if len(packages) > 1:
answer.add(module_name)
modules: Dict[str, Set[str]] = collections.defaultdict(set)
for m in self.all_messages.values():
for t in m.field_types:
modules[t.ident.module].add(t.ident.package)

answer.update(
module_name
for module_name, packages in modules.items()
if len(packages) > 1
)

# Return the set of collision names.
return frozenset(answer)
Expand Down Expand Up @@ -462,24 +465,24 @@ def proto(self) -> Proto:
return naive

# Return a context-aware proto object.
# Note: The services bind to themselves, because services get their
# own output files.
return dataclasses.replace(naive,
all_enums=collections.OrderedDict([
(k, v.with_context(collisions=naive.names))
for k, v in naive.all_enums.items()
]),
all_messages=collections.OrderedDict([
(k, v.with_context(collisions=naive.names))
for k, v in naive.all_messages.items()
]),
services=collections.OrderedDict([
(k, v.with_context(collisions=v.names))
for k, v in naive.services.items()
]),
meta=naive.meta.with_context(
collisions=naive.names),
)
return dataclasses.replace(
naive,
all_enums=collections.OrderedDict(
(k, v.with_context(collisions=naive.names))
for k, v in naive.all_enums.items()
),
all_messages=collections.OrderedDict(
(k, v.with_context(collisions=naive.names))
for k, v in naive.all_messages.items()
),
services=collections.OrderedDict(
# Note: services bind to themselves because services get their
# own output files.
(k, v.with_context(collisions=v.names))
for k, v in naive.services.items()
),
meta=naive.meta.with_context(collisions=naive.names),
)

@cached_property
def api_enums(self) -> Mapping[str, wrappers.EnumType]:
Expand Down
23 changes: 15 additions & 8 deletions gapic/schema/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,16 @@ def module_alias(self) -> str:
to users (albeit looking auto-generated).
"""
if self.module in self.collisions:
return '_'.join((
''.join([i[0] for i in self.package
if i != self.api_naming.version]),
self.module,
))
return '_'.join(
(
''.join(
i[0]
for i in self.package
if i != self.api_naming.version
),
self.module,
)
)
return ''

@property
Expand Down Expand Up @@ -161,7 +166,8 @@ def child(self, child_name: str, path: Tuple[int, ...]) -> 'Address':
Returns:
~.Address: The new address object.
"""
return dataclasses.replace(self,
return dataclasses.replace(
self,
module_path=self.module_path + path,
name=child_name,
parent=self.parent + (self.name,) if self.name else self.parent,
Expand Down Expand Up @@ -278,9 +284,10 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Metadata':
``Address`` object aliases module names to avoid naming collisions in
the file being written.
"""
return dataclasses.replace(self,
return dataclasses.replace(
self,
address=self.address.with_context(collisions=collisions),
)
)


@dataclasses.dataclass(frozen=True)
Expand Down
49 changes: 38 additions & 11 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,18 +330,19 @@ def with_context(self, *,
"""
return dataclasses.replace(
self,
fields=collections.OrderedDict([
fields=collections.OrderedDict(
(k, v.with_context(collisions=collisions))
for k, v in self.fields.items()
]) if not skip_fields else self.fields,
nested_enums=collections.OrderedDict([
) if not skip_fields else self.fields,
nested_enums=collections.OrderedDict(
(k, v.with_context(collisions=collisions))
for k, v in self.nested_enums.items()
]),
nested_messages=collections.OrderedDict([(k, v.with_context(
collisions=collisions,
skip_fields=skip_fields,
)) for k, v in self.nested_messages.items()]),
),
nested_messages=collections.OrderedDict(
(k, v.with_context(
collisions=collisions,
skip_fields=skip_fields,))
for k, v in self.nested_messages.items()),
meta=self.meta.with_context(collisions=collisions),
)

Expand Down Expand Up @@ -455,6 +456,23 @@ class OperationInfo:
response_type: MessageType
metadata_type: MessageType

def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo':
"""Return a derivative of this OperationInfo with the provided context.
This method is used to address naming collisions. The returned
``OperationInfo`` object aliases module names to avoid naming collisions
in the file being written.
"""
return dataclasses.replace(
self,
response_type=self.response_type.with_context(
collisions=collisions
),
metadata_type=self.metadata_type.with_context(
collisions=collisions
),
)


@dataclasses.dataclass(frozen=True)
class RetryInfo:
Expand Down Expand Up @@ -681,8 +699,13 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Method':
``Method`` object aliases module names to avoid naming collisions
in the file being written.
"""
maybe_lro = self.lro.with_context(
collisions=collisions
) if self.lro else None

return dataclasses.replace(
self,
lro=maybe_lro,
input=self.input.with_context(collisions=collisions),
output=self.output.with_context(collisions=collisions),
meta=self.meta.with_context(collisions=collisions),
Expand Down Expand Up @@ -810,9 +833,13 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Service':
"""
return dataclasses.replace(
self,
methods=collections.OrderedDict([
(k, v.with_context(collisions=collisions))
methods=collections.OrderedDict(
(k, v.with_context(
# A methodd's flattened fields create additional names
# that may conflict with module imports.
collisions=collisions | frozenset(v.flattened_fields.keys()))
)
for k, v in self.methods.items()
]),
),
meta=self.meta.with_context(collisions=collisions),
)
105 changes: 105 additions & 0 deletions tests/unit/schema/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pytest

from google.api import client_pb2
from google.api_core import exceptions
from google.longrunning import operations_pb2
from google.protobuf import descriptor_pb2
Expand Down Expand Up @@ -219,6 +220,110 @@ def test_proto_names_import_collision():
'other_message', 'primitive', 'spam'}


def test_proto_names_import_collision_flattening():
lro_proto = api.Proto.build(make_file_pb2(
name='operations.proto', package='google.longrunning',
messages=(make_message_pb2(name='Operation'),),
), file_to_generate=False, naming=make_naming())

fd = (
make_file_pb2(
name='mollusc.proto',
package='google.animalia.mollusca',
messages=(
make_message_pb2(name='Mollusc',),
make_message_pb2(name='MolluscResponse',),
make_message_pb2(name='MolluscMetadata',),
),
),
make_file_pb2(
name='squid.proto',
package='google.animalia.mollusca',
messages=(
make_message_pb2(
name='IdentifySquidRequest',
fields=(
make_field_pb2(
name='mollusc',
number=1,
type_name='.google.animalia.mollusca.Mollusc'
),
),
),
make_message_pb2(
name='IdentifySquidResponse',
fields=(),
),
),
services=(
descriptor_pb2.ServiceDescriptorProto(
name='SquidIdentificationService',
method=(
descriptor_pb2.MethodDescriptorProto(
name='IdentifyMollusc',
input_type='google.animalia.mollusca.IdentifySquidRequest',
output_type='google.longrunning.Operation',
),
),
),
),
),
)

method_options = fd[1].service[0].method[0].options
# Notice that a signature field collides with the name of an imported module
method_options.Extensions[client_pb2.method_signature].append('mollusc')
method_options.Extensions[operations_pb2.operation_info].MergeFrom(
operations_pb2.OperationInfo(
response_type='google.animalia.mollusca.MolluscResponse',
metadata_type='google.animalia.mollusca.MolluscMetadata',
)
)
api_schema = api.API.build(
fd,
package='google.animalia.mollusca',
prior_protos={
'google/longrunning/operations.proto': lro_proto,
}
)

actual_imports = {
ref_type.ident.python_import
for service in api_schema.services.values()
for method in service.methods.values()
for ref_type in method.ref_types
}

expected_imports = {
imp.Import(
package=('google', 'animalia', 'mollusca', 'types'),
module='mollusc',
alias='gam_mollusc',
),
imp.Import(
package=('google', 'animalia', 'mollusca', 'types'),
module='squid',
),
imp.Import(package=('google', 'api_core'), module='operation',),
}

assert expected_imports == actual_imports

method = (
api_schema
.services['google.animalia.mollusca.SquidIdentificationService']
.methods['IdentifyMollusc']
)

actual_response_import = method.lro.response_type.ident.python_import
expected_response_import = imp.Import(
package=('google', 'animalia', 'mollusca', 'types'),
module='mollusc',
alias='gam_mollusc',
)
assert actual_response_import == expected_response_import


def test_proto_builder_constructor():
sentinel_message = descriptor_pb2.DescriptorProto()
sentinel_enum = descriptor_pb2.EnumDescriptorProto()
Expand Down

0 comments on commit adcec6a

Please sign in to comment.