Skip to content

Commit

Permalink
Support gzip serialization (#3662)
Browse files Browse the repository at this point in the history
For comparison with #3655. Observed benchmark results, running locally:
```
nested: 7.016707049 s
size: 299304
flattened: 1.278540381 s
size: 20911
```
  • Loading branch information
95-martin-orion authored Feb 1, 2021
1 parent 97e9a91 commit 3c566d2
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@
qid_shape,
quil,
QuilFormatter,
read_json_gzip,
read_json,
resolve_parameters,
resolve_parameters_once,
Expand Down Expand Up @@ -521,6 +522,7 @@
SupportsQasmWithArgsAndQubits,
SupportsTraceDistanceBound,
SupportsUnitary,
to_json_gzip,
to_json,
obj_to_dict_helper,
trace_distance_bound,
Expand Down
2 changes: 2 additions & 0 deletions cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@
DEFAULT_RESOLVERS,
JsonResolver,
json_serializable_dataclass,
to_json_gzip,
read_json_gzip,
to_json,
read_json,
obj_to_dict_helper,
Expand Down
39 changes: 39 additions & 0 deletions cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import dataclasses
import functools
import gzip
import json
import numbers
import pathlib
Expand Down Expand Up @@ -706,3 +707,41 @@ def obj_hook(x):
return json.load(file, object_hook=obj_hook)

return json.load(cast(IO, file_or_fn), object_hook=obj_hook)


def to_json_gzip(
obj: Any,
file_or_fn: Union[None, IO, pathlib.Path, str] = None,
*,
indent: int = 2,
cls: Type[json.JSONEncoder] = CirqEncoder,
) -> Optional[bytes]:
json_str = to_json(obj, indent=indent, cls=cls)
if isinstance(file_or_fn, (str, pathlib.Path)):
with gzip.open(file_or_fn, 'wt', encoding='utf-8') as actually_a_file:
actually_a_file.write(json_str)
return None

gzip_data = gzip.compress(bytes(json_str, encoding='utf-8')) # type: ignore
if file_or_fn is None:
return gzip_data

file_or_fn.write(gzip_data)
return None


def read_json_gzip(
file_or_fn: Union[None, IO, pathlib.Path, str] = None,
*,
gzip_raw: Optional[bytes] = None,
resolvers: Optional[Sequence[JsonResolver]] = None,
):
if (file_or_fn is None) == (gzip_raw is None):
raise ValueError('Must specify ONE of "file_or_fn" or "gzip_raw".')

if gzip_raw is not None:
json_str = gzip.decompress(gzip_raw).decode(encoding='utf-8')
return read_json(json_text=json_str, resolvers=resolvers)

with gzip.open(file_or_fn, 'rt') as json_file: # type: ignore
return read_json(cast(IO, json_file), resolvers=resolvers)
43 changes: 43 additions & 0 deletions cirq/protocols/json_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,32 @@ def test_op_roundtrip_filename(tmpdir):
op2 = cirq.read_json(filename)
assert op1 == op2

gzip_filename = f'{tmpdir}/op.gz'
cirq.to_json_gzip(op1, gzip_filename)
assert os.path.exists(gzip_filename)
op3 = cirq.read_json_gzip(gzip_filename)
assert op1 == op3


def test_op_roundtrip_file_obj(tmpdir):
filename = f'{tmpdir}/op.json'
q = cirq.LineQubit(5)
op1 = cirq.rx(0.123).on(q)
with open(filename, 'w+') as file:
cirq.to_json(op1, file)
assert os.path.exists(filename)
file.seek(0)
op2 = cirq.read_json(file)
assert op1 == op2

gzip_filename = f'{tmpdir}/op.gz'
with open(gzip_filename, 'w+b') as gzip_file:
cirq.to_json_gzip(op1, gzip_file)
assert os.path.exists(gzip_filename)
gzip_file.seek(0)
op3 = cirq.read_json_gzip(gzip_file)
assert op1 == op3


def test_fail_to_resolve():
buffer = io.StringIO()
Expand Down Expand Up @@ -640,6 +666,19 @@ def test_to_from_strings():
cirq.read_json(io.StringIO(), json_text=x_json_text)


def test_to_from_json_gzip():
a, b = cirq.LineQubit.range(2)
test_circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b))
gzip_data = cirq.to_json_gzip(test_circuit)
unzip_circuit = cirq.read_json_gzip(gzip_raw=gzip_data)
assert test_circuit == unzip_circuit

with pytest.raises(ValueError):
_ = cirq.read_json_gzip(io.StringIO(), gzip_raw=gzip_data)
with pytest.raises(ValueError):
_ = cirq.read_json_gzip()


def _eval_repr_data_file(path: pathlib.Path):
return eval(
path.read_text(),
Expand Down Expand Up @@ -736,6 +775,10 @@ def test_pathlib_paths(tmpdir):
cirq.to_json(cirq.X, path)
assert cirq.read_json(path) == cirq.X

gzip_path = pathlib.Path(tmpdir) / 'op.gz'
cirq.to_json_gzip(cirq.X, gzip_path)
assert cirq.read_json_gzip(gzip_path) == cirq.X


def test_json_serializable_dataclass():
@cirq.json_serializable_dataclass
Expand Down
125 changes: 125 additions & 0 deletions dev_tools/profiling/benchmark_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2020 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tool for benchmarking serialization of large circuits.
This tool was originally introduced to enable comparison of the two JSON
serialization protocols (gzip and non-gzip):
https://github.com/quantumlib/Cirq/pull/3662
This is part of the "efficient serialization" effort:
https://github.com/quantumlib/Cirq/issues/3438
Run this benchmark with the following command (make sure to install cirq-dev):
python3 dev_tools/profiling/benchmark_serializers.py \
--num_gates=<int> --nesting_depth=<int> --num_repetitions=<int>
WARNING: runtime increases exponentially with nesting_depth. Values much
higher than nesting_depth=10 are not recommended.
"""

import argparse
import sys
import timeit

import numpy as np

import cirq

_JSON_GZIP = 'json_gzip'
_JSON = 'json'

NUM_QUBITS = 8

SUFFIXES = ['B', 'kB', 'MB', 'GB', 'TB']


def serialize(serializer: str, num_gates: int, nesting_depth: int) -> int:
""""Runs a round-trip of the serializer."""
circuit = cirq.Circuit()
for _ in range(num_gates):
which = np.random.choice(['expz', 'expw', 'exp11'])
if which == 'expw':
q1 = cirq.GridQubit(0, np.random.randint(NUM_QUBITS))
circuit.append(
cirq.PhasedXPowGate(
phase_exponent=np.random.random(), exponent=np.random.random()
).on(q1)
)
elif which == 'expz':
q1 = cirq.GridQubit(0, np.random.randint(NUM_QUBITS))
circuit.append(cirq.Z(q1) ** np.random.random())
elif which == 'exp11':
q1 = cirq.GridQubit(0, np.random.randint(NUM_QUBITS - 1))
q2 = cirq.GridQubit(0, q1.col + 1)
circuit.append(cirq.CZ(q1, q2) ** np.random.random())
cs = [circuit]
for _ in range(1, nesting_depth):
fc = cs[-1].freeze()
cs.append(cirq.Circuit(fc.to_op(), fc.to_op()))
test_circuit = cs[-1]

if serializer == _JSON:
json_data = cirq.to_json(test_circuit)
assert json_data is not None
data_size = len(json_data)
cirq.read_json(json_text=json_data)
elif serializer == _JSON_GZIP:
gzip_data = cirq.to_json_gzip(test_circuit)
assert gzip_data is not None
data_size = len(gzip_data)
cirq.read_json_gzip(gzip_raw=gzip_data)
return data_size


def main(
num_gates: int,
nesting_depth: int,
num_repetitions: int,
setup: str = 'from __main__ import serialize',
):
for serializer in [_JSON_GZIP, _JSON]:
print()
print(f'Using serializer "{serializer}":')
command = f'serialize(\'{serializer}\', {num_gates}, {nesting_depth})'
time = timeit.timeit(command, setup, number=num_repetitions)
print(f'Round-trip serializer time: {time / num_repetitions}s')
data_size = float(serialize(serializer, num_gates, nesting_depth))
suffix_idx = 0
while data_size > 1000:
data_size /= 1024
suffix_idx += 1
print(f'Serialized data size: {data_size} {SUFFIXES[suffix_idx]}.')


def parse_arguments(args):
parser = argparse.ArgumentParser('Benchmark a serializer.')
parser.add_argument(
'--num_gates', default=100, type=int, help='Number of gates at the bottom nesting layer.'
)
parser.add_argument(
'--nesting_depth',
default=1,
type=int,
help='Depth of nested subcircuits. Total gate count will be 2^nesting_depth * num_gates.',
)
parser.add_argument(
'--num_repetitions', default=10, type=int, help='Number of times to repeat serialization.'
)
return vars(parser.parse_args(args))


if __name__ == '__main__':
main(**parse_arguments(sys.argv[1:]))
53 changes: 53 additions & 0 deletions dev_tools/profiling/benchmark_serializers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the simulator benchmarker."""

from dev_tools.profiling import benchmark_serializers


def test_gzip_serializer():
for num_gates in (10, 20):
for nesting_depth in (1, 8):
benchmark_serializers.serialize('json_gzip', num_gates, nesting_depth)


def test_json_serializer():
for num_gates in (10, 20):
for nesting_depth in (1, 8):
benchmark_serializers.serialize('json', num_gates, nesting_depth)


def test_args_have_defaults():
kwargs = benchmark_serializers.parse_arguments([])
for _, v in kwargs.items():
assert v is not None


def test_main_loop():
# Keep test from taking a long time by lowering max qubits.
benchmark_serializers.main(
**benchmark_serializers.parse_arguments({}),
setup='from dev_tools.profiling.benchmark_serializers import serialize',
)


def test_parse_args():
args = ('--num_gates 5 --nesting_depth 8 --num_repetitions 2').split()
kwargs = benchmark_serializers.parse_arguments(args)
assert kwargs == {
'num_gates': 5,
'nesting_depth': 8,
'num_repetitions': 2,
}

0 comments on commit 3c566d2

Please sign in to comment.