Skip to content

Commit

Permalink
Enforce unique serialization keys (#3673)
Browse files Browse the repository at this point in the history
This PR modifies the `SerializableByKey` serialization process to assign a unique suffix to each serialized object. This way keys are guaranteed to be unique within a serialized object - we no longer rely on `_serialization_key_` to provide unique values.

Keys generated this way are consistent between serializations of the same top-level object, provided that each component object has a consistent serialization. However, the key of an object may change between serializations if the top-level object changes. For example:
```
# Pseudocode - each letter represents a SerializableByKey object.
a = [x, y]  # suffix of key(x) is "_1"
b = [y, x]  # suffix of key(x) is "_2"
```

Prerequisite to #3655, which will require updates to remove hashes from serialization keys. Technically a breaking change, but only for unreleased behavior.
  • Loading branch information
95-martin-orion authored Feb 13, 2021
1 parent 1e9323c commit 0677d60
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
59 changes: 29 additions & 30 deletions cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Optional,
overload,
Sequence,
Set,
Type,
Union,
)
Expand Down Expand Up @@ -332,12 +333,11 @@ class SerializableByKey(SupportsJSON):
"""Protocol for objects that can be serialized to a key + context."""

@doc_private
def _serialization_key_(self) -> str:
"""Returns a unique string identifier for this object.
def _serialization_name_(self) -> str:
"""Returns a human-readable string identifier for this object.
This should only return the same value for two objects if they are
equal; otherwise, an error will occur if both are serialized into the
same JSON string.
This identifier need not be globally unique; a unique suffix will be
dynamically generated to create the serialization key for this object.
"""


Expand All @@ -350,8 +350,8 @@ class _SerializedKey(SupportsJSON):
the original `obj` for this type.
"""

def __init__(self, obj: SerializableByKey):
self.key = obj._serialization_key_()
def __init__(self, key: str):
self.key = key

def _json_dict_(self):
return obj_to_dict_helper(self, ['key'])
Expand All @@ -374,8 +374,8 @@ class _SerializedContext(SupportsJSON):
the original `obj` for this type.
"""

def __init__(self, obj: SerializableByKey):
self.key = obj._serialization_key_()
def __init__(self, obj: SerializableByKey, suffix: int):
self.key = f'{obj._serialization_name_()}_{suffix}'
self.obj = obj

def _json_dict_(self):
Expand Down Expand Up @@ -403,12 +403,12 @@ def __init__(self, obj: Any):
# Context information and the wrapped object are stored together in
# `object_dag` to ensure consistent serialization ordering.
self.object_dag = []
context_keys = set()
context = []
for sbk in get_serializable_by_keys(obj):
new_sc = _SerializedContext(sbk)
if new_sc.key not in context_keys:
if sbk not in context:
context.append(sbk)
new_sc = _SerializedContext(sbk, len(context))
self.object_dag.append(new_sc)
context_keys.add(new_sc.key)
self.object_dag += [obj]

def _json_dict_(self):
Expand All @@ -426,7 +426,7 @@ def deserialize_with_context(cls, object_dag, **kwargs):

def has_serializable_by_keys(obj: Any) -> bool:
"""Returns true if obj contains one or more SerializableByKey objects."""
if hasattr(obj, '_serialization_key_'):
if hasattr(obj, '_serialization_name_'):
return True
json_dict = getattr(obj, '_json_dict_', lambda: None)()
if isinstance(json_dict, Dict):
Expand All @@ -447,7 +447,7 @@ def get_serializable_by_keys(obj: Any) -> List[SerializableByKey]:
are nested inside. This is required to ensure
"""
result = []
if hasattr(obj, '_serialization_key_'):
if hasattr(obj, '_serialization_name_'):
result.append(obj)
json_dict = getattr(obj, '_json_dict_', lambda: None)()
if isinstance(json_dict, Dict):
Expand Down Expand Up @@ -504,29 +504,28 @@ def to_json(
to your classes rather than overriding this default.
"""
if has_serializable_by_keys(obj):
obj = _ContextualSerialization(obj)

class ContextualEncoder(cls): # type: ignore
"""An encoder with a context map for concise serialization."""

# This map is populated gradually during serialization. An object
# with components defined in this map will represent those
# These lists populate gradually during serialization. An object
# with components defined in 'context' will represent those
# components using their keys instead of inline definition.
context_map: Dict[str, 'SerializableByKey'] = {}
seen: Set[str] = set()

def default(self, o):
skey = getattr(o, '_serialization_key_', lambda: None)()
if skey in ContextualEncoder.context_map:
if ContextualEncoder.context_map[skey] == o._json_dict_():
return _SerializedKey(o)._json_dict_()
raise ValueError(
'Found different objects with the same serialization key:'
f'\n{ContextualEncoder.context_map[skey]}\n{o}'
)
if skey is not None:
ContextualEncoder.context_map[skey] = o._json_dict_()
return super().default(o)
if not hasattr(o, '_serialization_name_'):
return super().default(o)
for candidate in obj.object_dag[:-1]:
if candidate.obj == o:
if not candidate.key in ContextualEncoder.seen:
ContextualEncoder.seen.add(candidate.key)
return candidate.obj._json_dict_()
else:
return _SerializedKey(candidate.key)._json_dict_()
raise ValueError("Object mutated during serialization.") # coverage: ignore

obj = _ContextualSerialization(obj)
cls = ContextualEncoder

if file_or_fn is None:
Expand Down
14 changes: 8 additions & 6 deletions cirq/protocols/json_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _json_dict_(self):
"data_dict": self.data_dict,
}

def _serialization_key_(self):
def _serialization_name_(self):
return self.name

@classmethod
Expand Down Expand Up @@ -371,7 +371,7 @@ def custom_resolver(name):
final_obj
== """{
"cirq_type": "_SerializedKey",
"key": "sbki_dict"
"key": "sbki_dict_4"
}"""
)

Expand All @@ -382,15 +382,17 @@ def custom_resolver(name):
assert_json_roundtrip_works(dict_sbki, resolvers=test_resolvers)

assert sbki_list != json_serialization._SerializedKey(sbki_list)

# Serialization keys have unique suffixes.
sbki_other_list = SBKImpl('sbki_list', data_list=[sbki_list])
with pytest.raises(ValueError, match='different objects with the same serialization key'):
_ = cirq.to_json(sbki_other_list)
assert_json_roundtrip_works(sbki_other_list, resolvers=test_resolvers)


def test_internal_serializer_types():
sbki = SBKImpl('test_key')
test_key = json_serialization._SerializedKey(sbki)
test_context = json_serialization._SerializedContext(sbki)
key = f'{sbki._serialization_name_()}_1'
test_key = json_serialization._SerializedKey(key)
test_context = json_serialization._SerializedContext(sbki, 1)
test_serialization = json_serialization._ContextualSerialization(sbki)

key_json = test_key._json_dict_()
Expand Down

0 comments on commit 0677d60

Please sign in to comment.