Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use concise JSON serialization for FrozenCircuit. #3655

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ def moments(self) -> Sequence['cirq.Moment']:
def device(self) -> devices.Device:
pass

@property
@abc.abstractmethod
def name(self) -> Optional[str]:
pass

def freeze(self) -> 'cirq.FrozenCircuit':
"""Creates a FrozenCircuit from this circuit.

Expand All @@ -119,7 +124,9 @@ def freeze(self) -> 'cirq.FrozenCircuit':

if isinstance(self, FrozenCircuit):
return self
return FrozenCircuit(self, strategy=InsertStrategy.EARLIEST, device=self.device)
return FrozenCircuit(
self, strategy=InsertStrategy.EARLIEST, device=self.device, name=self.name
)

def unfreeze(self) -> 'cirq.Circuit':
"""Creates a Circuit from this circuit.
Expand All @@ -128,15 +135,19 @@ def unfreeze(self) -> 'cirq.Circuit':
"""
if isinstance(self, Circuit):
return Circuit.copy(self)
return Circuit(self, strategy=InsertStrategy.EARLIEST, device=self.device)
return Circuit(self, strategy=InsertStrategy.EARLIEST, device=self.device, name=self.name)

def __bool__(self):
return bool(self.moments)

def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented
return self.moments == other.moments and self.device == other.device
return (
self.moments == other.moments
and self.device == other.device
and self.name == other.name
)

def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool:
"""See `cirq.protocols.SupportsApproximateEquality`."""
Expand Down Expand Up @@ -1220,11 +1231,13 @@ def save_qasm(
self._to_qasm_output(header, precision, qubit_order).save(file_path)

def _json_dict_(self):
return protocols.obj_to_dict_helper(self, ['moments', 'device'])
return protocols.obj_to_dict_helper(self, ['moments', 'device', 'name'])

@classmethod
def _from_json_dict_(cls, moments, device, **kwargs):
return cls(moments, strategy=InsertStrategy.EARLIEST, device=device)
return cls(
moments, strategy=InsertStrategy.EARLIEST, device=device, name=kwargs.get('name', None)
)


class Circuit(AbstractCircuit):
Expand Down Expand Up @@ -1306,6 +1319,7 @@ def __init__(
*contents: 'cirq.OP_TREE',
strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST,
device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE,
name: Optional[str] = None,
) -> None:
"""Initializes a circuit.

Expand All @@ -1320,9 +1334,12 @@ def __init__(
together. This option does not affect later insertions into the
circuit.
device: Hardware that the circuit should be able to run on.
name: A user-specified key to identify frozen copies of this
circuit in diagrams and serialization.
"""
self._moments: List['cirq.Moment'] = []
self._device = device
self._name = name
self.append(contents, strategy=strategy)

@property
Expand All @@ -1334,6 +1351,10 @@ def device(self, new_device: 'cirq.Device') -> None:
new_device.validate_circuit(self)
self._device = new_device

@property
def name(self) -> Optional[str]:
return self._name

def __copy__(self) -> 'Circuit':
return self.copy()

Expand Down
3 changes: 2 additions & 1 deletion cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3935,14 +3935,15 @@ def with_qubits(self, *qubits):
@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_json_dict(circuit_cls):
q0, q1 = cirq.LineQubit.range(2)
c = circuit_cls(cirq.CNOT(q0, q1))
c = circuit_cls(cirq.CNOT(q0, q1), name='test_circuit')
moments = [cirq.Moment([cirq.CNOT(q0, q1)])]
if circuit_cls == cirq.FrozenCircuit:
moments = tuple(moments)
assert c._json_dict_() == {
'cirq_type': circuit_cls.__name__,
'moments': moments,
'device': cirq.UNCONSTRAINED_DEVICE,
'name': 'test_circuit',
}


Expand Down
19 changes: 14 additions & 5 deletions cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
*contents: 'cirq.OP_TREE',
strategy: 'cirq.InsertStrategy' = InsertStrategy.EARLIEST,
device: 'cirq.Device' = devices.UNCONSTRAINED_DEVICE,
name: Optional[str] = None,
) -> None:
"""Initializes a frozen circuit.

Expand All @@ -62,10 +63,13 @@ def __init__(
from `contents`, this determines how the operations are packed
together.
device: Hardware that the circuit should be able to run on.
name: A user-specified key to identify this circuit in diagrams
and serialization.
"""
base = Circuit(contents, strategy=strategy, device=device)
self._moments = tuple(base.moments)
self._device = base.device
self._name = name

# These variables are memoized when first requested.
self._num_qubits: Optional[int] = None
Expand All @@ -85,13 +89,19 @@ def moments(self) -> Sequence['cirq.Moment']:
def device(self) -> devices.Device:
return self._device

@property
def name(self) -> Optional[str]:
return self._name

def __hash__(self):
return hash((self.moments, self.device))
return hash((self.moments, self.device, self.name))

def serialization_key(self):
# TODO: use this key in serialization and support user-specified keys.
def serialization_key(self) -> str:
key = hash(self) & 0xFFFF_FFFF_FFFF_FFFF
return f'Circuit_0x{key:016x}'
return f'{self.name or "Circuit"}_0x{key:016x}'

def _serialization_key_(self) -> str:
return self.serialization_key()

# Memoized methods for commonly-retrieved properties.

Expand Down Expand Up @@ -146,7 +156,6 @@ def __radd__(self, other) -> 'FrozenCircuit':
# Needed for numpy to handle multiplication by np.int64 correctly.
__array_priority__ = 10000

# TODO: handle multiplication / powers differently?
def __mul__(self, other) -> 'FrozenCircuit':
return (self.unfreeze() * other).freeze()

Expand Down
20 changes: 19 additions & 1 deletion cirq/circuits/frozen_circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def test_freeze_and_unfreeze():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(a), cirq.H(b))
c = cirq.Circuit(cirq.X(a), cirq.H(b), name='testname')

f = c.freeze()
assert f.moments == tuple(c.moments)
Expand All @@ -45,6 +45,24 @@ def test_freeze_and_unfreeze():
assert fcc is not f


def test_names():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(a), cirq.H(b), name='testname')

f = c.freeze()
assert f.name == c.name

ff = c.freeze()
assert ff.name == f.name
# Since the two are identical, this is acceptable.
assert ff.serialization_key() == f.serialization_key()

c.append(cirq.CX(b, a))
fff = c.freeze()
assert fff.name == f.name
assert fff.serialization_key() != f.serialization_key()


def test_immutable():
q = cirq.LineQubit(0)
c = cirq.FrozenCircuit(cirq.X(q), cirq.H(q))
Expand Down
3 changes: 3 additions & 0 deletions cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,9 @@ def _serialization_key_(self) -> str:
This should only return the same value for two objects if they are
equal; otherwise, an error will occur if both are serialized into the
same JSON string.

Tests will ignore a single hash value ("0x" + 16 hex digits) in each
serialization key to allow for variance between Python sessions.
"""


Expand Down
23 changes: 21 additions & 2 deletions cirq/protocols/json_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import inspect

import datetime
import inspect
import io
import json
import os
import pathlib
import re
import textwrap
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type

Expand Down Expand Up @@ -651,6 +651,20 @@ def _eval_repr_data_file(path: pathlib.Path):
)


def replace_trailing_hash(src: str, dst: str) -> str:
"""Update serialization keys in dst to use hash values from src."""
match = '_Serialized(Context|Key).*\n.*"key".*?0x[0-9a-f]{16}'
re_src = re.finditer(match, src, re.M)
re_dst = re.finditer(match, dst, re.M)
for s, d in zip(re_src, re_dst):
dst = (
dst[: d.start()]
+ src[s.start() : s.end()]
+ dst[d.end() :]
)
return dst


def assert_repr_and_json_test_data_agree(
repr_path: pathlib.Path, json_path: pathlib.Path, inward_only: bool
):
Expand Down Expand Up @@ -683,6 +697,11 @@ def assert_repr_and_json_test_data_agree(

if not inward_only:
json_from_cirq = cirq.to_json(repr_obj)
if json_serialization.has_serializable_by_keys(repr_obj):
# Hash values are not consistent between python sessions.
# Replace each hash in hardcoded JSON file with the value from repr.
json_from_file = replace_trailing_hash(json_from_cirq, json_from_file)

json_from_cirq_obj = json.loads(json_from_cirq)
json_from_file_obj = json.loads(json_from_file)

Expand Down
3 changes: 2 additions & 1 deletion cirq/protocols/json_test_data/CalibrationLayer.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
],
"device": {
"cirq_type": "_UnconstrainedDevice"
}
},
"name": null
},
"args": {
"type": "full",
Expand Down
9 changes: 6 additions & 3 deletions cirq/protocols/json_test_data/Circuit.json
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@
],
"device": {
"cirq_type": "_UnconstrainedDevice"
}
},
"name": "group_measure"
},
{
"cirq_type": "Circuit",
Expand Down Expand Up @@ -170,7 +171,8 @@
],
"device": {
"cirq_type": "_UnconstrainedDevice"
}
},
"name": null
},
{
"cirq_type": "Circuit",
Expand Down Expand Up @@ -200,6 +202,7 @@
],
"device": {
"cirq_type": "_UnconstrainedDevice"
}
},
"name": null
}
]
3 changes: 2 additions & 1 deletion cirq/protocols/json_test_data/Circuit.repr
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
cirq.Moment(
cirq.MeasurementGate(5, '0,1,2,3,4', ()).on(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2), cirq.LineQubit(3), cirq.LineQubit(4)),
),
]), cirq.Circuit([
], name="group_measure",
), cirq.Circuit([
cirq.Moment(
cirq.TOFFOLI(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2)),
),
Expand Down
Loading