diff --git a/cirq/__init__.py b/cirq/__init__.py index 6c5ea14f2f4..3844c3c4008 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -493,6 +493,7 @@ qid_shape, quil, QuilFormatter, + read_json_gzip, read_json, resolve_parameters, resolve_parameters_once, @@ -521,6 +522,7 @@ SupportsQasmWithArgsAndQubits, SupportsTraceDistanceBound, SupportsUnitary, + to_json_gzip, to_json, obj_to_dict_helper, trace_distance_bound, diff --git a/cirq/protocols/__init__.py b/cirq/protocols/__init__.py index 7c1fdf7e87d..d2101e9d507 100644 --- a/cirq/protocols/__init__.py +++ b/cirq/protocols/__init__.py @@ -79,6 +79,8 @@ DEFAULT_RESOLVERS, JsonResolver, json_serializable_dataclass, + to_json_gzip, + read_json_gzip, to_json, read_json, obj_to_dict_helper, diff --git a/cirq/protocols/json_serialization.py b/cirq/protocols/json_serialization.py index 564604f4b88..b5c66f515fe 100644 --- a/cirq/protocols/json_serialization.py +++ b/cirq/protocols/json_serialization.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses import functools +import gzip import json import numbers import pathlib @@ -706,3 +707,41 @@ def obj_hook(x): return json.load(file, object_hook=obj_hook) return json.load(cast(IO, file_or_fn), object_hook=obj_hook) + + +def to_json_gzip( + obj: Any, + file_or_fn: Union[None, IO, pathlib.Path, str] = None, + *, + indent: int = 2, + cls: Type[json.JSONEncoder] = CirqEncoder, +) -> Optional[bytes]: + json_str = to_json(obj, indent=indent, cls=cls) + if isinstance(file_or_fn, (str, pathlib.Path)): + with gzip.open(file_or_fn, 'wt', encoding='utf-8') as actually_a_file: + actually_a_file.write(json_str) + return None + + gzip_data = gzip.compress(bytes(json_str, encoding='utf-8')) # type: ignore + if file_or_fn is None: + return gzip_data + + file_or_fn.write(gzip_data) + return None + + +def read_json_gzip( + file_or_fn: Union[None, IO, pathlib.Path, str] = None, + *, + gzip_raw: Optional[bytes] = None, + resolvers: Optional[Sequence[JsonResolver]] = None, +): + if (file_or_fn is None) == (gzip_raw is None): + raise ValueError('Must specify ONE of "file_or_fn" or "gzip_raw".') + + if gzip_raw is not None: + json_str = gzip.decompress(gzip_raw).decode(encoding='utf-8') + return read_json(json_text=json_str, resolvers=resolvers) + + with gzip.open(file_or_fn, 'rt') as json_file: # type: ignore + return read_json(cast(IO, json_file), resolvers=resolvers) diff --git a/cirq/protocols/json_serialization_test.py b/cirq/protocols/json_serialization_test.py index 9b51326e1f0..4c825dcf0c6 100644 --- a/cirq/protocols/json_serialization_test.py +++ b/cirq/protocols/json_serialization_test.py @@ -91,6 +91,32 @@ def test_op_roundtrip_filename(tmpdir): op2 = cirq.read_json(filename) assert op1 == op2 + gzip_filename = f'{tmpdir}/op.gz' + cirq.to_json_gzip(op1, gzip_filename) + assert os.path.exists(gzip_filename) + op3 = cirq.read_json_gzip(gzip_filename) + assert op1 == op3 + + +def test_op_roundtrip_file_obj(tmpdir): + filename = f'{tmpdir}/op.json' + q = cirq.LineQubit(5) + op1 = cirq.rx(0.123).on(q) + with open(filename, 'w+') as file: + cirq.to_json(op1, file) + assert os.path.exists(filename) + file.seek(0) + op2 = cirq.read_json(file) + assert op1 == op2 + + gzip_filename = f'{tmpdir}/op.gz' + with open(gzip_filename, 'w+b') as gzip_file: + cirq.to_json_gzip(op1, gzip_file) + assert os.path.exists(gzip_filename) + gzip_file.seek(0) + op3 = cirq.read_json_gzip(gzip_file) + assert op1 == op3 + def test_fail_to_resolve(): buffer = io.StringIO() @@ -640,6 +666,19 @@ def test_to_from_strings(): cirq.read_json(io.StringIO(), json_text=x_json_text) +def test_to_from_json_gzip(): + a, b = cirq.LineQubit.range(2) + test_circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b)) + gzip_data = cirq.to_json_gzip(test_circuit) + unzip_circuit = cirq.read_json_gzip(gzip_raw=gzip_data) + assert test_circuit == unzip_circuit + + with pytest.raises(ValueError): + _ = cirq.read_json_gzip(io.StringIO(), gzip_raw=gzip_data) + with pytest.raises(ValueError): + _ = cirq.read_json_gzip() + + def _eval_repr_data_file(path: pathlib.Path): return eval( path.read_text(), @@ -736,6 +775,10 @@ def test_pathlib_paths(tmpdir): cirq.to_json(cirq.X, path) assert cirq.read_json(path) == cirq.X + gzip_path = pathlib.Path(tmpdir) / 'op.gz' + cirq.to_json_gzip(cirq.X, gzip_path) + assert cirq.read_json_gzip(gzip_path) == cirq.X + def test_json_serializable_dataclass(): @cirq.json_serializable_dataclass diff --git a/dev_tools/profiling/benchmark_serializers.py b/dev_tools/profiling/benchmark_serializers.py new file mode 100644 index 00000000000..2dcc3d5fa8d --- /dev/null +++ b/dev_tools/profiling/benchmark_serializers.py @@ -0,0 +1,125 @@ +# Copyright 2020 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool for benchmarking serialization of large circuits. + +This tool was originally introduced to enable comparison of the two JSON +serialization protocols (gzip and non-gzip): +https://github.com/quantumlib/Cirq/pull/3662 + +This is part of the "efficient serialization" effort: +https://github.com/quantumlib/Cirq/issues/3438 + +Run this benchmark with the following command (make sure to install cirq-dev): + + python3 dev_tools/profiling/benchmark_serializers.py \ + --num_gates= --nesting_depth= --num_repetitions= + +WARNING: runtime increases exponentially with nesting_depth. Values much +higher than nesting_depth=10 are not recommended. +""" + +import argparse +import sys +import timeit + +import numpy as np + +import cirq + +_JSON_GZIP = 'json_gzip' +_JSON = 'json' + +NUM_QUBITS = 8 + +SUFFIXES = ['B', 'kB', 'MB', 'GB', 'TB'] + + +def serialize(serializer: str, num_gates: int, nesting_depth: int) -> int: + """"Runs a round-trip of the serializer.""" + circuit = cirq.Circuit() + for _ in range(num_gates): + which = np.random.choice(['expz', 'expw', 'exp11']) + if which == 'expw': + q1 = cirq.GridQubit(0, np.random.randint(NUM_QUBITS)) + circuit.append( + cirq.PhasedXPowGate( + phase_exponent=np.random.random(), exponent=np.random.random() + ).on(q1) + ) + elif which == 'expz': + q1 = cirq.GridQubit(0, np.random.randint(NUM_QUBITS)) + circuit.append(cirq.Z(q1) ** np.random.random()) + elif which == 'exp11': + q1 = cirq.GridQubit(0, np.random.randint(NUM_QUBITS - 1)) + q2 = cirq.GridQubit(0, q1.col + 1) + circuit.append(cirq.CZ(q1, q2) ** np.random.random()) + cs = [circuit] + for _ in range(1, nesting_depth): + fc = cs[-1].freeze() + cs.append(cirq.Circuit(fc.to_op(), fc.to_op())) + test_circuit = cs[-1] + + if serializer == _JSON: + json_data = cirq.to_json(test_circuit) + assert json_data is not None + data_size = len(json_data) + cirq.read_json(json_text=json_data) + elif serializer == _JSON_GZIP: + gzip_data = cirq.to_json_gzip(test_circuit) + assert gzip_data is not None + data_size = len(gzip_data) + cirq.read_json_gzip(gzip_raw=gzip_data) + return data_size + + +def main( + num_gates: int, + nesting_depth: int, + num_repetitions: int, + setup: str = 'from __main__ import serialize', +): + for serializer in [_JSON_GZIP, _JSON]: + print() + print(f'Using serializer "{serializer}":') + command = f'serialize(\'{serializer}\', {num_gates}, {nesting_depth})' + time = timeit.timeit(command, setup, number=num_repetitions) + print(f'Round-trip serializer time: {time / num_repetitions}s') + data_size = float(serialize(serializer, num_gates, nesting_depth)) + suffix_idx = 0 + while data_size > 1000: + data_size /= 1024 + suffix_idx += 1 + print(f'Serialized data size: {data_size} {SUFFIXES[suffix_idx]}.') + + +def parse_arguments(args): + parser = argparse.ArgumentParser('Benchmark a serializer.') + parser.add_argument( + '--num_gates', default=100, type=int, help='Number of gates at the bottom nesting layer.' + ) + parser.add_argument( + '--nesting_depth', + default=1, + type=int, + help='Depth of nested subcircuits. Total gate count will be 2^nesting_depth * num_gates.', + ) + parser.add_argument( + '--num_repetitions', default=10, type=int, help='Number of times to repeat serialization.' + ) + return vars(parser.parse_args(args)) + + +if __name__ == '__main__': + main(**parse_arguments(sys.argv[1:])) diff --git a/dev_tools/profiling/benchmark_serializers_test.py b/dev_tools/profiling/benchmark_serializers_test.py new file mode 100644 index 00000000000..b7a47f99e69 --- /dev/null +++ b/dev_tools/profiling/benchmark_serializers_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the simulator benchmarker.""" + +from dev_tools.profiling import benchmark_serializers + + +def test_gzip_serializer(): + for num_gates in (10, 20): + for nesting_depth in (1, 8): + benchmark_serializers.serialize('json_gzip', num_gates, nesting_depth) + + +def test_json_serializer(): + for num_gates in (10, 20): + for nesting_depth in (1, 8): + benchmark_serializers.serialize('json', num_gates, nesting_depth) + + +def test_args_have_defaults(): + kwargs = benchmark_serializers.parse_arguments([]) + for _, v in kwargs.items(): + assert v is not None + + +def test_main_loop(): + # Keep test from taking a long time by lowering max qubits. + benchmark_serializers.main( + **benchmark_serializers.parse_arguments({}), + setup='from dev_tools.profiling.benchmark_serializers import serialize', + ) + + +def test_parse_args(): + args = ('--num_gates 5 --nesting_depth 8 --num_repetitions 2').split() + kwargs = benchmark_serializers.parse_arguments(args) + assert kwargs == { + 'num_gates': 5, + 'nesting_depth': 8, + 'num_repetitions': 2, + }