diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index be86af8..e9f6060 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.8', '3.9', '3.10', '3.11'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] os: [ubuntu-latest] runs-on: ${{ matrix.os }} @@ -24,7 +24,7 @@ jobs: - name: Upgrade test dependencies run: python -m pip install psutil pytest 'hypothesis[zoneinfo]' qiskit - name: Install JAX - run: pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + run: pip install jax - name: Install QFactor JAX run: pip install . - name: Run tests diff --git a/LICENSE b/LICENSE index 1325218..ebdb3b6 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2023, +Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2024, U.S. Federal Government and the Government of Israel. All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index 72a1190..089cdde 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -# QFactor implementation on GPUs using JAX -`bqskit-qfactor-jax` is a Python package that implements circuit instantiation with [QFactor](https://arxiv.org/abs/2306.08152) on GPUs to accelerate [BQSKit](https://github.com/bqskit/bqskit). It uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as an abstraction layer of the GPUs, seamlessly utilizing JIT compilation and GPU parallelism. +# QFactor and QFactor-Sample implementations on GPUs using JAX +`bqskit-qfactor-jax` is a Python package that implements circuit instantiation using the [QFactor](https://ieeexplore.ieee.org/abstract/document/10313638) and [QFactor-Sample](https://arxiv.org/abs/2405.12866) algorithms on GPUs to accelerate [BQSKit](https://github.com/bqskit/bqskit). 
It uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as an abstraction layer of the GPUs, seamlessly utilizing JIT compilation and GPU parallelism. ## Installation `bqskit-qfactor-jax` is available for Python 3.8+ on Linux. @@ -21,7 +21,7 @@ pip install bqskit-qfactor-jax # Running bqskit-qfactor-jax Please set the environment variable XLA_PYTHON_CLIENT_PREALLOCATE=False when using this package. Also, if you encounter OOM issues consider setting XLA_PYTHON_CLIENT_ALLOCATOR=platform. -Please take a look at the [examples](https://github.com/BQSKit/bqskit-qfactor-jax/tree/main/examples) to see some basic usage. +Please look at the [examples](https://github.com/BQSKit/bqskit-qfactor-jax/tree/main/examples) for basic usage, especially the performance comparison between QFactor and QFactor-Sample. When using several workers on the same GPU, we recommend using [Nvidia's MPS](https://docs.nvidia.com/deploy/mps/index.html). You may initiate it using the command line ```sh @@ -34,7 +34,11 @@ echo quit | nvidia-cuda-mps-control ``` # References -Kukliansky, Alon, et al. "QFactor:A Domain-Specific Optimizer for Quantum Circuit Instantiation." arXiv preprint [arXiv:2306.08152](https://arxiv.org/abs/2306.08152) (2023). +If you are using QFactor-JAX, please cite:\ +Kukliansky, Alon, et al. "QFactor: A Domain-Specific Optimizer for Quantum Circuit Instantiation." 2023 IEEE International Conference on Quantum Computing and Engineering (QCE). Vol. 1. IEEE, 2023. [Link](https://ieeexplore.ieee.org/abstract/document/10313638). + +If you are using QFactor-Sample, please cite:\ +Kukliansky, Alon, et al. "Leveraging Quantum Machine Learning Generalization to Significantly Speed-up Quantum Compilation." arXiv preprint [arXiv:2405.12866](https://arxiv.org/abs/2405.12866) (2024). ## License The software in this repository is licensed under a **BSD free software @@ -45,5 +49,5 @@ for more information. 
## Copyright -Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2023, +Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2024, U.S. Federal Government and the Government of Israel. All rights reserved. diff --git a/examples/adder63_10q_block_28.qasm b/examples/adder63_10q_block_28.qasm new file mode 100644 index 0000000..92c612c --- /dev/null +++ b/examples/adder63_10q_block_28.qasm @@ -0,0 +1,87 @@ +OPENQASM 2.0; +include "qelib1.inc"; +qreg q[10]; +cx q[7], q[9]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[9]; +cx q[7], q[9]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[9]; +cx q[7], q[9]; +cx q[5], q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[5]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; +cx q[5], q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[5]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; +cx q[5], q[7]; +cx q[4], q[5]; +cx q[7], q[9]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[4]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[5]; +u3(0.0, 0.0, 0.7853981633974483) q[7]; +cx q[4], q[5]; +cx q[6], q[7]; +u3(0.0, 0.0, 5.497787143782138) q[9]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[4]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[5]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[6]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; +cx q[4], q[5]; +cx q[6], q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[4]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[6]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; +cx q[2], q[4]; +cx q[6], q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[2]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[4]; +u3(1.5707963267948966, 0.0, 6.283185307179586) q[6]; +cx q[7], q[8]; +u3(0.0, 0.0, 0.7853981633974483) q[2]; +cx q[3], q[4]; +u3(1.5707963267948966, 0.0, 
3.141592653589793) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[8]; +cx q[0], q[2]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[3]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[4]; +cx q[7], q[8]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[0]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[2]; +cx q[3], q[4]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[8]; +cx q[0], q[2]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[3]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[4]; +cx q[7], q[8]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[0]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[2]; +cx q[3], q[4]; +cx q[7], q[9]; +cx q[0], q[2]; +cx q[1], q[3]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[4]; +u3(1.5707963267948966, 0.0, -3.141592653589793) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[1]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[3]; +cx q[6], q[7]; +u3(0.0, 0.0, 7.0685834705770345) q[9]; +cx q[1], q[3]; +u3(1.5707963267948966, 2.356194490192345, -3.141592653589793) q[6]; +u3(1.5707963267948966, -2.356194490192345, 3.141592653589793) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[1]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[3]; +cx q[7], q[9]; +cx q[1], q[3]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[9]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[1]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[3]; +cx q[7], q[9]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[9]; +cx q[7], q[9]; +cx q[6], q[7]; +u3(0.0, 0.0, 5.497787143782138) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[7]; diff --git a/examples/compare_qfactor_sample_to_qfactor.py b/examples/compare_qfactor_sample_to_qfactor.py new file mode 100644 index 0000000..ce9da82 --- /dev/null +++ 
b/examples/compare_qfactor_sample_to_qfactor.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import argparse +import time + +from bqskit import enable_logging +from bqskit.compiler import CompilationTask +from bqskit.compiler import Compiler +from bqskit.ir.circuit import Circuit +from bqskit.passes import ToVariablePass + +from qfactorjax.qfactor import QFactorJax +from qfactorjax.qfactor_sample_jax import QFactorSampleJax + +enable_logging(verbose=True) + + +parser = argparse.ArgumentParser( + description='Comparing the re-instantiation run time of QFactor-JAX and ' + 'QFactor-Sample-JAX. Running on adder63_10q_block_28.qasm the ' + 'difference is X4, and for heisenberg64_10q_block_104.qasm ' + 'the difference is X10 and QFactor-JAX doesn\'t find a solution. ' + 'For vqe12_10q_block145.qasm the difference is X34.', +) + +parser.add_argument('--input_qasm', type=str, required=True) +parser.add_argument('--multistarts', type=int, default=32) +parser.add_argument('--max_iters', type=int, default=6000) +parser.add_argument('--dist_tol', type=float, default=1e-8) +parser.add_argument('--num_params_coef', type=int, default=1) +parser.add_argument('--exact_amount_of_sample_states', type=int) +parser.add_argument('--overtrain_relative_threshold', type=float, default=0.1) + + +params = parser.parse_args() + + +print(params) + +file_name = params.input_qasm +dist_tol_requested = params.dist_tol +num_mutlistarts = params.multistarts +max_iters = params.max_iters + +num_params_coef = params.num_params_coef + +exact_amount_of_sample_states = params.exact_amount_of_sample_states +overtrain_relative_threshold = params.overtrain_relative_threshold + + +instantiate_options = { + 'multistarts': num_mutlistarts, +} + + +qfactor_gpu_instantiator = QFactorJax( + + dist_tol=dist_tol_requested, # Stopping criteria for distance + + max_iters=100000, # Maximum number of iterations + min_iters=1, # Minimum number of iterations + + # One step plateau detection - + # diff_tol_a 
+ diff_tol_r * |c(i)| <= |c(i)|-|c(i-1)| + diff_tol_a=0.0, # Stopping criteria for distance change + diff_tol_r=1e-10, # Relative criteria for distance change + + # Long plateau detection - + # diff_tol_step_r*|c(i-diff_tol_step)| <= |c(i)|-|c(i-diff_tol_step)| + diff_tol_step_r=0.1, # The relative improvement expected + diff_tol_step=200, # The interval in which to check the improvement + + # Regularization parameter - [0.0 - 1.0] + # Increase to overcome local minima at the price of longer compute + beta=0.0, +) + + +qfactor_sample_gpu_instantiator = QFactorSampleJax( + + dist_tol=dist_tol_requested, # Stopping criteria for distance + + max_iters=max_iters, # Maximum number of iterations + min_iters=6, # Minimum number of iterations + + # Regularization parameter - [0.0 - 1.0] + # Increase to overcome local minima at the price of longer compute + beta=0.0, + + amount_of_validation_states=2, + # num_params_coef (below) sets the initial training-sample size: + # int(sqrt(number of circuit parameters) * num_params_coef). + diff_tol_r=1e-4, + num_params_coef=num_params_coef, + overtrain_relative_threshold=overtrain_relative_threshold, + exact_amount_of_states_to_train_on=exact_amount_of_sample_states, +) + + +print( + f'Will use {file_name} {dist_tol_requested = } {num_mutlistarts = }' + f' {num_params_coef = }', +) + +orig_10q_block_cir = Circuit.from_file(f'{file_name}') + +with Compiler(num_workers=1) as compiler: + task = CompilationTask(orig_10q_block_cir, [ToVariablePass()]) + task_id = compiler.submit(task) + orig_10q_block_cir_vu = compiler.result(task_id) + + +tic = time.perf_counter() +target = orig_10q_block_cir_vu.get_unitary() +time_to_simulate_circ = time.perf_counter() - tic +print(f'Time to simulate was {time_to_simulate_circ}') + +tic = time.perf_counter() +orig_10q_block_cir_vu.instantiate( + target, multistarts=num_mutlistarts, method=qfactor_sample_gpu_instantiator, +) +sample_inst_time = time.perf_counter() - tic +inst_sample_dist_from_target = orig_10q_block_cir_vu.get_unitary( 
+).get_distance_from(target, 1) + +print( + f'QFactor-Sample-JAX {sample_inst_time = } ' + f'{inst_sample_dist_from_target = }' + f' {num_params_coef = }', +) + +tic = time.perf_counter() +orig_10q_block_cir_vu.instantiate( + target, multistarts=num_mutlistarts, method=qfactor_gpu_instantiator, +) +full_inst_time = time.perf_counter() - tic +inst_dist_from_target = orig_10q_block_cir_vu.get_unitary().get_distance_from( + target, 1, +) + +print(f'QFactor-JAX {full_inst_time = } {inst_dist_from_target = }') diff --git a/examples/gate_deletion_syth.py b/examples/gate_deletion_syth.py index 4d9c785..23af2c6 100644 --- a/examples/gate_deletion_syth.py +++ b/examples/gate_deletion_syth.py @@ -7,6 +7,7 @@ from timeit import default_timer as timer from bqskit import Circuit +from bqskit import enable_logging from bqskit.compiler import Compiler from bqskit.passes import ForEachBlockPass from bqskit.passes import QuickPartitioner @@ -18,6 +19,9 @@ from qfactorjax.qfactor import QFactorJax +enable_logging() + + def run_gate_del_flow_example( amount_of_workers: int = 10, ) -> tuple[Circuit, Circuit, float]: @@ -102,7 +106,9 @@ def run_gate_del_flow_example( if __name__ == '__main__': - in_circuit, out_circuit, run_time = run_gate_del_flow_example() + in_circuit, out_circuit, run_time = run_gate_del_flow_example( + amount_of_workers=1, + ) print( f'Partitioning + Synthesis took {run_time}' diff --git a/examples/heisenberg64_10q_block_104.qasm b/examples/heisenberg64_10q_block_104.qasm new file mode 100644 index 0000000..d878485 --- /dev/null +++ b/examples/heisenberg64_10q_block_104.qasm @@ -0,0 +1,57 @@ +OPENQASM 2.0; +include "qelib1.inc"; +qreg q[10]; +u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[1]; +cx q[0], q[1]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[2]; +u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[3]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[4]; +u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) 
q[5]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[6]; +u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[7]; +u3(1.5707963267948966, 0.0, 3.141592653589793) q[8]; +u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[9]; +u3(0.0, 0.0, 0.02) q[1]; +cx q[0], q[1]; +u3(1.5707963267948966, 0.0, 1.5707963267948966) q[0]; +u3(0.0, 1.406583, -1.406583) q[1]; +cx q[1], q[2]; +u3(0.0, 0.0, 0.02) q[2]; +cx q[1], q[2]; +u3(1.5707963267948966, 0.0, 1.5707963267948966) q[1]; +u3(0.0, 1.406583, -1.406583) q[2]; +cx q[2], q[3]; +u3(0.0, 0.0, 0.02) q[3]; +cx q[2], q[3]; +u3(1.5707963267948966, 0.0, 1.5707963267948966) q[2]; +u3(0.0, 1.406583, -1.406583) q[3]; +cx q[3], q[4]; +u3(0.0, 0.0, 0.02) q[4]; +cx q[3], q[4]; +u3(1.5707963267948966, 0.0, 1.5707963267948966) q[3]; +u3(0.0, 1.406583, -1.406583) q[4]; +cx q[4], q[5]; +u3(0.0, 0.0, 0.02) q[5]; +cx q[4], q[5]; +u3(1.5707963267948966, 0.0, 1.5707963267948966) q[4]; +u3(0.0, 1.406583, -1.406583) q[5]; +cx q[5], q[6]; +u3(0.0, 0.0, 0.02) q[6]; +cx q[5], q[6]; +u3(1.5707963267948966, 0.0, 1.5707963267948966) q[5]; +u3(0.0, 1.406583, -1.406583) q[6]; +cx q[6], q[7]; +u3(0.0, 0.0, 0.02) q[7]; +cx q[6], q[7]; +u3(1.5707963267948966, 0.0, 1.5707963267948966) q[6]; +u3(0.0, 1.406583, -1.406583) q[7]; +cx q[7], q[8]; +u3(0.0, 0.0, 0.02) q[8]; +cx q[7], q[8]; +u3(1.5707963267948966, 0.0, 1.5707963267948966) q[7]; +u3(0.0, 1.406583, -1.406583) q[8]; +cx q[8], q[9]; +u3(0.0, 0.0, 0.02) q[9]; +cx q[8], q[9]; +u3(1.5707963267948966, 0.0, 1.5707963267948966) q[8]; +u3(0.0, 1.406583, -1.406583) q[9]; diff --git a/examples/toffoli_instantiation_using_sampling.py b/examples/toffoli_instantiation_using_sampling.py new file mode 100644 index 0000000..f886b47 --- /dev/null +++ b/examples/toffoli_instantiation_using_sampling.py @@ -0,0 +1,69 @@ +""" +Numerical Instantiation is the foundation of many of BQSKit's algorithms. 
+ +This is the same instantiation example as in BQSKit using the GPU implementation +of QFactor-Sample. +""" +from __future__ import annotations + +import numpy as np +from bqskit.ir.circuit import Circuit +from bqskit.ir.gates import VariableUnitaryGate +from bqskit.qis.unitary import UnitaryMatrix + +from qfactorjax.qfactor_sample_jax import QFactorSampleJax + + +def run_toffoli_instantiation(dist_tol_requested: float = 1e-10) -> float: + qfactor_gpu_instantiator = QFactorSampleJax( + + dist_tol=dist_tol_requested, # Stopping criteria for distance + + max_iters=100000, # Maximum number of iterations + min_iters=10, # Minimum number of iterations + + # Regularization parameter - [0.0 - 1.0] + # Increase to overcome local minima at the price of longer compute + beta=0.0, + + amount_of_validation_states=2, + num_params_coef=1, + overtrain_relative_threshold=0.1, + ) + + # We will optimize towards the Toffoli unitary. + toffoli = np.array([ + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 1, 0], + ]) + toffoli = UnitaryMatrix(toffoli) + + # Start with the circuit structure + circuit = Circuit(3) + circuit.append_gate(VariableUnitaryGate(2), [1, 2]) + circuit.append_gate(VariableUnitaryGate(2), [0, 2]) + circuit.append_gate(VariableUnitaryGate(2), [1, 2]) + circuit.append_gate(VariableUnitaryGate(2), [0, 2]) + circuit.append_gate(VariableUnitaryGate(2), [0, 1]) + + # Instantiate the circuit template with QFactor-Sample + circuit.instantiate( + toffoli, + multistarts=16, + method=qfactor_gpu_instantiator, + ) + + # Calculate and return the final distance + dist = circuit.get_unitary().get_distance_from(toffoli, 1) + return dist + + +if __name__ == '__main__': + dist = run_toffoli_instantiation() + print('Final Distance: ', dist) diff --git a/examples/vqe12_10q_block145.qasm b/examples/vqe12_10q_block145.qasm new file mode 
100644 index 0000000..e6db212 --- /dev/null +++ b/examples/vqe12_10q_block145.qasm @@ -0,0 +1,60 @@ +OPENQASM 2.0; +include "qelib1.inc"; +qreg q[10]; +cx q[9], q[8]; +u3(1.5707963267948966, 1.5707963267948966, -1.5707963267948966) q[8]; +cx q[8], q[7]; +cx q[7], q[6]; +cx q[6], q[5]; +cx q[5], q[4]; +cx q[4], q[3]; +cx q[3], q[2]; +cx q[2], q[1]; +cx q[1], q[0]; +u3(0.0, 0.0, -0.379188312199599) q[0]; +cx q[1], q[0]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[0]; +cx q[2], q[1]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[1]; +cx q[3], q[2]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[2]; +cx q[4], q[3]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[3]; +cx q[5], q[4]; +cx q[6], q[5]; +u3(1.5707963267948966, 1.5707963267948966, -1.5707963267948966) q[5]; +cx q[7], q[6]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[6]; +cx q[8], q[7]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[7]; +u3(1.5707963267948966, -1.5707963267948966, 1.5707963267948966) q[8]; +cx q[9], q[8]; +cx q[8], q[7]; +cx q[7], q[6]; +cx q[6], q[5]; +cx q[5], q[3]; +cx q[3], q[2]; +cx q[2], q[1]; +cx q[1], q[0]; +u3(0.0, 0.0, 0.379188312199599) q[0]; +cx q[1], q[0]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[0]; +cx q[2], q[1]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[1]; +cx q[3], q[2]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[2]; +cx q[5], q[3]; +u3(1.5707963267948966, -1.5707963267948966, 1.5707963267948966) q[3]; +cx q[6], q[5]; +u3(1.5707963267948966, -1.5707963267948966, 1.5707963267948966) q[5]; +cx q[7], q[6]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[6]; +cx q[8], q[7]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[7]; +cx q[9], q[8]; +u3(0.0, 1.4065829705916304, -1.4065829705916302) q[8]; +cx q[9], q[8]; +cx q[8], q[7]; +cx q[7], q[6]; +cx q[6], q[5]; +cx q[5], q[4]; diff --git a/qfactorjax/__init__.py b/qfactorjax/__init__.py index bac6f89..1e9a2a2 100644 --- a/qfactorjax/__init__.py +++ 
b/qfactorjax/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations __all__ = [ - 'qfactor_jax', + 'qfactor', 'unitary_acc', 'unitarybuilderjax', 'unitarymatrixjax', diff --git a/qfactorjax/qfactor.py b/qfactorjax/qfactor.py index 5b80f00..acf6509 100644 --- a/qfactorjax/qfactor.py +++ b/qfactorjax/qfactor.py @@ -30,7 +30,7 @@ from bqskit.qis.state.system import StateSystemLike from bqskit.qis.unitary.unitarymatrix import UnitaryLike -_logger = logging.getLogger(__name__) +_logger = logging.getLogger('bqskit.instant.qf-jax') jax.config.update('jax_enable_x64', True) diff --git a/qfactorjax/qfactor_sample_jax.py b/qfactorjax/qfactor_sample_jax.py new file mode 100644 index 0000000..1705fc4 --- /dev/null +++ b/qfactorjax/qfactor_sample_jax.py @@ -0,0 +1,732 @@ +from __future__ import annotations + +import logging +import os +from enum import Enum +from typing import Sequence +from typing import TYPE_CHECKING + +import jax +import jax.numpy as jnp +import numpy as np +import numpy.typing as npt +from bqskit.ir import CircuitLocation +from bqskit.ir import Gate +from bqskit.ir.gates import ConstantGate +from bqskit.ir.gates import U3Gate +from bqskit.ir.gates import VariableUnitaryGate +from bqskit.ir.opt import Instantiater +from bqskit.qis import UnitaryMatrix +from bqskit.qis.state import StateSystem +from bqskit.qis.state import StateVector +from jax import Array +from jax._src.lib import xla_extension as xe +from scipy.stats import unitary_group + +from qfactorjax.qfactor import _apply_padding_and_flatten +from qfactorjax.qfactor import _remove_padding_and_create_matrix +from qfactorjax.singlelegedtensor import LHSTensor +from qfactorjax.singlelegedtensor import RHSTensor +from qfactorjax.singlelegedtensor import SingleLegSideTensor +from qfactorjax.unitary_acc import VariableUnitaryGateAcc +from qfactorjax.unitarymatrixjax import UnitaryMatrixJax + +if TYPE_CHECKING: + from bqskit.ir.circuit import Circuit + from bqskit.qis.state.state import StateLike + 
from bqskit.qis.state.system import StateSystemLike + from bqskit.qis.unitary.unitarymatrix import UnitaryLike + +_logger = logging.getLogger('bqskit.instant.qf-sample-jax') + + +jax.config.update('jax_enable_x64', True) + + +class TermCondition(Enum): + UNKNOWN = 0 + REACHED_TARGET = 1 + EXCEEDED_MAX_ITER = 2 + PLATEAU_DETECTED = 3 + EXCEEDED_TRAINING_SET_SIZE = 4 + + +class QFactorSampleJax(Instantiater): + + def __init__( + self, + dist_tol: float = 1e-8, + max_iters: int = 100000, + min_iters: int = 2, + beta: float = 0.0, + amount_of_validation_states: int = 2, + num_params_coef: float = 1.0, + overtrain_relative_threshold: float = 0.1, + diff_tol_r: float = 1e-4, # Relative criteria for distance change + plateau_windows_size: int = 6, + exact_amount_of_states_to_train_on: int | None = None, + ): + + if not isinstance(dist_tol, float) or dist_tol > 0.5: + raise TypeError('Invalid distance threshold.') + + if not isinstance(max_iters, int) or max_iters < 0: + raise TypeError('Invalid maximum number of iterations.') + + if not isinstance(min_iters, int) or min_iters < 0: + raise TypeError('Invalid minimum number of iterations.') + + self.dist_tol = dist_tol + self.max_iters = max_iters + self.min_iters = min_iters + + self.beta = beta + self.amount_of_validation_states = amount_of_validation_states + self.num_params_coef = num_params_coef + self.overtrain_ratio = overtrain_relative_threshold + self.diff_tol_r = diff_tol_r + self.plateau_windows_size = plateau_windows_size + self.exact_amount_of_states_to_train_on =\ + exact_amount_of_states_to_train_on + + self.targets_training_set_size_cache: dict[int, int] = {} + + def instantiate( + self, + circuit: Circuit, + target: UnitaryMatrix | StateVector | StateSystem, + x0: npt.NDArray[np.float64], + ) -> npt.NDArray[np.float64]: + + return self.multi_start_instantiate(circuit, target, 1) + + def multi_start_instantiate_inplace( + self, + circuit: Circuit, + target: UnitaryLike | StateLike | StateSystemLike, + 
num_starts: int, + ) -> None: + """ + Instantiate `circuit` to best implement `target` with multiple starts. + + See :func:`multi_start_instantiate` for more info. + + Notes: + This method is a version of :func:`multi_start_instantiate` + that modifies `circuit` in place rather than returning a copy. + """ + target = self.check_target(target) + params = self.multi_start_instantiate(circuit, target, num_starts) + circuit.set_params(params) + + async def multi_start_instantiate_async( + self, + circuit: Circuit, + target: UnitaryLike | StateLike | StateSystemLike, + num_starts: int, + ) -> npt.NDArray[np.float64]: + + return self.multi_start_instantiate(circuit, target, num_starts) + + def multi_start_instantiate( + self, + circuit: Circuit, + target: UnitaryLike | StateLike | StateSystemLike, + num_starts: int, + ) -> npt.NDArray[np.float64]: + + if len(circuit) == 0: + return np.array([]) + + circuit = circuit.copy() + + gates_list = [] + for op in circuit: + if isinstance(op.gate, VariableUnitaryGate): + new_gate = VariableUnitaryGateAcc( + op.gate.num_qudits, + op.gate.radixes, + ) + gates_list.append(new_gate) + else: + gates_list.append(op.gate) + + target_hash = target.__hash__() + target = UnitaryMatrixJax(self.check_target(target)) + radixes = target.radixes + num_qudits = target.num_qudits + locations = tuple([op.location for op in circuit]) + gates = tuple(gates_list) + biggest_gate_dim = max(g.dim for g in circuit.gate_set) + + if target_hash in self.targets_training_set_size_cache: + initial_amount_of_training_states =\ + self.targets_training_set_size_cache[target_hash] + elif self.exact_amount_of_states_to_train_on is None: + amount_of_params_in_circuit = 0 + for g in gates: + amount_of_params_in_circuit += g.num_params + initial_amount_of_training_states = int( + np.sqrt(amount_of_params_in_circuit) * self.num_params_coef, + ) + else: + initial_amount_of_training_states =\ + self.exact_amount_of_states_to_train_on + + if np.prod(radixes) < 
initial_amount_of_training_states: + _logger.warning( + f'Requested to use ' + f'{initial_amount_of_training_states} training ' + 'states, while the prod of the radixes is ' + f'{np.prod(radixes)}. This algorithm shines when we have ' + f'much less training states than 2^prod(radixes)', + ) + + amount_of_training_states = initial_amount_of_training_states + + validation_states_kets = self.generate_random_states( + self.amount_of_validation_states, int(np.prod(radixes)), + ) + + generate_untrys_only_once = 'GEN_ONCE' in os.environ + + if generate_untrys_only_once: + untrys = [] + + for gate in gates: + size_of_untry = 2**gate.num_qudits + + if isinstance(gate, VariableUnitaryGateAcc): + pre_padding_untrys = [ + unitary_group.rvs(size_of_untry) for + _ in range(num_starts) + ] + else: + pre_padding_untrys = [ + gate.get_unitary().numpy for + _ in range(num_starts) + ] + + untrys.append([ + _apply_padding_and_flatten( + pre_padd, gate, biggest_gate_dim, + ) for pre_padd in pre_padding_untrys + ]) + + untrys = jnp.array(np.stack(untrys, axis=1)) + + term_condition = None + should_double_the_training_size = True + while term_condition is None: + if not generate_untrys_only_once: + untrys = [] + + for gate in gates: + size_of_untry = 2**gate.num_qudits + + if isinstance(gate, VariableUnitaryGateAcc): + pre_padding_untrys = [ + unitary_group.rvs(size_of_untry) for + _ in range(num_starts) + ] + else: + pre_padding_untrys = [ + gate.get_unitary().numpy for + _ in range(num_starts) + ] + + untrys.append([ + _apply_padding_and_flatten( + pre_padd, gate, biggest_gate_dim, + ) for pre_padd in pre_padding_untrys + ]) + + untrys = jnp.array(np.stack(untrys, axis=1)) + + training_states_kets = self.generate_random_states( + amount_of_training_states, int(np.prod(radixes)), + ) + + results = self.safe_call_jited_vmaped_state_sample_sweep( + target, num_starts, radixes, num_qudits, locations, gates, + validation_states_kets, untrys, training_states_kets, + ) + + final_untrys, 
training_costs, validation_costs, iteration_counts, \ + plateau_windows = results + + it = iteration_counts[0] + untrys = final_untrys + best_start = jnp.argmin(training_costs) + + if any(training_costs < self.dist_tol): + _logger.debug( + f'Terminated: {it} c1 = {training_costs} <= dist_tol.\n' + f'Best start is {best_start}', + ) + term_condition = TermCondition.REACHED_TARGET + elif it >= self.max_iters: + _logger.debug( + f'Terminated {it}: iteration limit reached. c1 = ' + f'{training_costs}', + ) + term_condition = TermCondition.EXCEEDED_MAX_ITER + elif it > self.min_iters: + val_to_train_diff = validation_costs - training_costs + + if np.all(np.all(plateau_windows, axis=1)): + _logger.debug( + f'Terminated: {it} plateau detected in all' + f' multistarts c1 = {training_costs}', + ) + term_condition = TermCondition.PLATEAU_DETECTED + + elif all( + val_to_train_diff > self.overtrain_ratio * training_costs, + ): + _logger.debug( + f'Terminated: {it} overtraining detected in' + f' all multistarts', + ) + + else: + term_condition = TermCondition.UNKNOWN + else: + term_condition = TermCondition.UNKNOWN + + if term_condition == TermCondition.UNKNOWN: + _logger.error( + f'Terminated with no good reason after {it} iterations ' + f'with c1s {training_costs}.', + ) + + if ( + term_condition == TermCondition.REACHED_TARGET + or term_condition == TermCondition.PLATEAU_DETECTED + or term_condition == TermCondition.EXCEEDED_MAX_ITER + ): + self.targets_training_set_size_cache[target_hash] =\ + amount_of_training_states + + if term_condition is None: + if should_double_the_training_size: + amount_of_training_states *= 2 + else: + amount_of_training_states +=\ + initial_amount_of_training_states + + if amount_of_training_states > np.prod(radixes): + term_condition = TermCondition.EXCEEDED_TRAINING_SET_SIZE + _logger.debug( + 'Stopping as we reached the max number of' + ' training states', + ) + + params: list[Sequence[float]] = [] + for untry, gate in zip(untrys[best_start], 
gates): + if isinstance(gate, ConstantGate): + params.extend([]) + else: + params.extend( + gate.get_params( + _remove_padding_and_create_matrix(untry, gate), + ), + ) + + return np.array(params) + + def safe_call_jited_vmaped_state_sample_sweep( + self, + target: UnitaryMatrixJax, + num_starts: int, + radixes: tuple[int, ...], + num_qudits: int, + locations: tuple[CircuitLocation, ...], + gates: tuple[Gate, ...], + validation_states_kets: Array, + untrys: Array, + training_states_kets: Array, + ) -> tuple[Array, Array[float], Array[float], Array[int], Array[bool]]: + """We couldn't find a way to check if we are going to allocate more than + the GPU memory, so we created this "safe" function that calls qfactor- + sample and then if OOM exception is caught it recursively calls qfactor- + sample with half the multistarts.""" + + try: + results = _jited_loop_vmaped_state_sample_sweep( + target, num_qudits, radixes, locations, gates, untrys, + self.dist_tol, self.max_iters, self.beta, + num_starts, self.min_iters, self.diff_tol_r, + self.plateau_windows_size, self.overtrain_ratio, + training_states_kets, validation_states_kets, + ) + + except xe.XlaRuntimeError as e: + + if num_starts == 1: + _logger.error( + f'Got a runtime error {e}, while {num_starts = } ,exiting', + ) + raise e + + _logger.debug( + f'Got a runtime error {e} will try re-run with half the starts', + ) + + mid_point = num_starts // 2 + first_half_untrys = untrys[:mid_point] + second_half_untrys = untrys[mid_point:] + + results1 = self.safe_call_jited_vmaped_state_sample_sweep( + target, mid_point, radixes, num_qudits, locations, gates, + validation_states_kets, first_half_untrys, + training_states_kets, + ) + + results2 = self.safe_call_jited_vmaped_state_sample_sweep( + target, num_starts - mid_point, radixes, num_qudits, locations, + gates, validation_states_kets, second_half_untrys, + training_states_kets, + ) + + # TODO: Fix the typing ignore here + results = tuple( + jnp.concatenate((results1[i], 
results2[i])) + for i in range(5) + ) # type: ignore + + return results + + @staticmethod + def get_method_name() -> str: + """Return the name of this method.""" + return 'qfactor_jax_batched_jit' + + @staticmethod + def can_internally_perform_multistart() -> bool: + """Probes if the instantiater can internally perform multistart.""" + return True + + @staticmethod + def is_capable(circuit: Circuit) -> bool: + """Return true if the circuit can be instantiated.""" + return all( + isinstance( + gate, ( + VariableUnitaryGate, + VariableUnitaryGateAcc, U3Gate, ConstantGate, + ), + ) + for gate in circuit.gate_set + ) + + @staticmethod + def get_violation_report(circuit: Circuit) -> str: + """ + Return a message explaining why `circuit` cannot be instantiated. + + Args: + circuit (Circuit): Generate a report for this circuit. + + Raises: + ValueError: If `circuit` can be instantiated with this + instantiater. + """ + + invalid_gates = { + gate + for gate in circuit.gate_set + if not isinstance( + gate, ( + VariableUnitaryGate, + VariableUnitaryGateAcc, + U3Gate, + ConstantGate, + ), + ) + } + + if len(invalid_gates) == 0: + raise ValueError('Circuit can be instantiated.') + + return ( + 'Cannot instantiate circuit with qfactor because' + ' the following gates are not locally optimizable with jax: %s.' + % ', '.join(str(g) for g in invalid_gates) + ) + + @staticmethod + def generate_random_states( + amount_of_states: int, + size_of_state: int, + ) -> list[npt.NDArray[np.complex128]]: + """ + Generate a list of random state vectors (kets) using random unitary + matrices. + + This function generates a specified number of random quantum state + vectors (kets) by creating random unitary matrices and extracting + their first columns. + + Args: + amount_of_states (int): The number of random states to generate. + size_of_state (int): The dimension of each state vector (ket). + + Returns: + list of ndarrays: A list containing random quantum state vectors. 
+ Each ket is represented as a numpy ndarray of + shape (size_of_state, 1). + """ + states_kets = [] + states_to_add = amount_of_states + while states_to_add > 0: + # We generate a random unitary and take its columns + rand_unitary = unitary_group.rvs(size_of_state) + states_to_add_in_step = min(states_to_add, size_of_state) + for i in range(states_to_add_in_step): + states_kets.append(rand_unitary[:, i:i + 1]) + states_to_add -= states_to_add_in_step + + return states_kets + + +def _loop_vmaped_state_sample_sweep( + target: UnitaryMatrixJax, num_qudits: int, radixes: tuple[int, ...], + locations: tuple[CircuitLocation, ...], + gates: tuple[Gate, ...], untrys: Array, + dist_tol: float, max_iters: int, beta: float, + amount_of_starts: int, min_iters: int, + diff_tol_r: float, plateau_windows_size: int, + overtrain_ratio: float, training_states_kets: Array, + validation_states_kets: Array, +) -> tuple[Array, Array[float], Array[float], Array[int], Array[bool]]: + + # Calculate the bras for the validation and training states + validation_states_bras = jax.vmap( + lambda ket: ket.T.conj(), + )(jnp.array(validation_states_kets)) + + training_states_bras = jax.vmap( + lambda ket: ket.T.conj(), + )(jnp.array(training_states_kets)) + + # Calculate the A and B0 tensor + target_dagger = target.T.conj() + A_train = RHSTensor( + list_of_states=training_states_bras, + num_qudits=num_qudits, radixes=radixes, + ) + A_train.apply_left(target_dagger, range(num_qudits)) + + A_val = RHSTensor( + list_of_states=validation_states_bras, + num_qudits=num_qudits, radixes=radixes, + ) + A_val.apply_left(target_dagger, range(num_qudits)) + + B0_train = LHSTensor( + list_of_states=training_states_kets, + num_qudits=num_qudits, radixes=radixes, + ) + B0_val = LHSTensor( + list_of_states=validation_states_kets, + num_qudits=num_qudits, radixes=radixes, + ) + + # In JAX the body of a while must be a function that accepts and returns + # the same type, and also the check should be a function 
that accepts it + # and return a boolean + + def should_continue( + loop_var: tuple[ + Array, Array[float], Array[float], Array[int], Array[bool], + ], + ) -> Array[bool]: + _, training_costs, validation_costs, \ + iteration_counts, plateau_windows = loop_var + + any_reached_required_tol = jnp.any( + jax.vmap( + lambda cost: cost <= dist_tol, + )(training_costs), + ) + + reached_max_iteration = iteration_counts[0] > max_iters + above_min_iteration = iteration_counts[0] > min_iters + + val_to_train_diff = validation_costs - training_costs + all_reached_over_training = jnp.all( + val_to_train_diff > overtrain_ratio * training_costs, + ) + + all_reached_plateau = jnp.all( + jnp.all(plateau_windows, axis=1), + ) + + return jnp.logical_not( + jnp.logical_or( + any_reached_required_tol, + jnp.logical_or( + reached_max_iteration, + jnp.logical_and( + above_min_iteration, + jnp.logical_or( + all_reached_over_training, + all_reached_plateau, + ), + ), + ), + ), + ) + + def _while_body_to_be_vmaped( + loop_var: tuple[ + Array, Array[float], Array[float], Array[int], Array[bool], + ], + ) -> tuple[ + Array, Array[float], Array[float], Array[int], Array[bool], + ]: + + untrys, training_cost, validation_cost, iteration_count, \ + plateau_window = loop_var + + untrys_as_matrixes: list[UnitaryMatrixJax] = [] + for gate_index, gate in enumerate(gates): + untrys_as_matrixes.append( + UnitaryMatrixJax( + _remove_padding_and_create_matrix( + untrys[gate_index], gate, + ), gate.radixes, + ), + ) + prev_training_cost = training_cost + + untrys_as_matrixes, training_cost, validation_cost =\ + state_sample_single_sweep( + locations, gates, untrys_as_matrixes, + beta, A_train, A_val, B0_train, B0_val, + ) + + iteration_count += 1 + + have_detected_plateau_in_curr_iter = jnp.abs( + prev_training_cost - training_cost, + ) <= diff_tol_r * jnp.abs(training_cost) + + plateau_window = jnp.concatenate( + ( + jnp.array([have_detected_plateau_in_curr_iter]), + plateau_window[:-1], + ), + ) + + 
biggest_gate_dim = max(g.dim for g in gates) + final_untrys_padded = jnp.array([ + _apply_padding_and_flatten( + untry.numpy.flatten( + ), gate, biggest_gate_dim, + ) for untry, gate in zip(untrys_as_matrixes, gates) + ]) + + return ( + final_untrys_padded, training_cost, validation_cost, + iteration_count, plateau_window, + ) + + while_body_vmaped = jax.vmap(_while_body_to_be_vmaped) + + initial_loop_var = ( + untrys, + jnp.ones(amount_of_starts), # train_cost + jnp.ones(amount_of_starts), # val_cost + jnp.zeros(amount_of_starts, dtype=int), # iter_count + np.zeros((amount_of_starts, plateau_windows_size), dtype=bool), + ) + + if 'PRINT_LOSS_QFACTOR' in os.environ: + loop_var = initial_loop_var + i = 1 + while should_continue(loop_var): + loop_var = while_body_vmaped(loop_var) + + untrys, training_costs, validation_costs, iteration_counts, \ + plateau_windows = loop_var + _logger.debug(f'TRAINLOSS{i}: {training_costs}') + _logger.debug(f'VALLOSS{i}: {validation_costs}') + i += 1 + r = loop_var + else: + r = jax.lax.while_loop( + should_continue, while_body_vmaped, initial_loop_var, + ) + + final_untrys, training_costs, validation_costs, iteration_counts, \ + plateau_windows = r + + return ( + final_untrys, training_costs, validation_costs, + iteration_counts, plateau_windows, + ) + + +def state_sample_single_sweep( + locations: tuple[CircuitLocation, ...], + gates: tuple[Gate, ...], + untrys: list[UnitaryMatrixJax], beta: float, + A_train: RHSTensor, A_val: RHSTensor, + B0_train: LHSTensor, B0_val: LHSTensor, +) -> tuple[list[UnitaryMatrixJax], float, float]: + + amount_of_gates = len(gates) + B = [B0_train] + for location, utry in zip(locations[:-1], untrys[:-1]): + B.append(B[-1].copy()) + B[-1].apply_right(utry, location) + + # iterate over every gate from right to left and update it + new_untrys_rev: list[UnitaryMatrixJax] = [] + a_train: RHSTensor = A_train.copy() + a_val: RHSTensor = A_val.copy() + for idx in reversed(range(amount_of_gates)): + b = B[idx] + 
gate = gates[idx] + location = locations[idx] + utry = untrys[idx] + if gate.num_params > 0: + env = SingleLegSideTensor.calc_env(b, a_train, location) + utry = gate.optimize( + env.T, get_untry=True, + prev_untry=utry, beta=beta, + ) + + new_untrys_rev.append(utry) + a_train.apply_left(utry, location) + a_val.apply_left(utry, location) + + untrys = new_untrys_rev[::-1] + + training_cost = calc_cost(A_train, B0_train, a_train) + + validation_cost = calc_cost(A_val, B0_val, a_val) + + return untrys, training_cost, validation_cost + + +def calc_cost(A: RHSTensor, B0: LHSTensor, a: RHSTensor) -> float: + cost = 2 * ( + 1 - jnp.real( + SingleLegSideTensor.calc_env(B0, a, [])[0], + ) / A.single_leg_radix + ) + + return jnp.squeeze(cost) + + +if 'NO_JIT_QFACTOR' in os.environ or 'PRINT_LOSS_QFACTOR' in os.environ: + _jited_loop_vmaped_state_sample_sweep = _loop_vmaped_state_sample_sweep +else: + _jited_loop_vmaped_state_sample_sweep = jax.jit( + _loop_vmaped_state_sample_sweep, static_argnums=( + 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, + ), + ) diff --git a/qfactorjax/singlelegedtensor.py b/qfactorjax/singlelegedtensor.py new file mode 100644 index 0000000..4801ce7 --- /dev/null +++ b/qfactorjax/singlelegedtensor.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from typing import Sequence +from typing import TypeVar + +import jax.numpy as jnp +from bqskit.ir import CircuitLocation +from jax import Array + +from qfactorjax.unitarymatrixjax import UnitaryMatrixJax + + +T = TypeVar('T', bound='SingleLegSideTensor') + + +class SingleLegSideTensor(): + """ + The class represents a tensor that has only a single leg in one of its + sides. 
+ + The single leg will always be index 0 + """ + + def __init__( + self, num_qudits: int, radixes: Sequence[int] = [], + list_of_states: Array = jnp.array([]), + tensor: Array | None = None, single_leg_radix: int | None = None, + ) -> None: + + if len(list_of_states) > 0: + first_state = list_of_states[0] + + assert any(d == 1 for d in first_state.shape) + assert all(s.shape == first_state.shape for s in list_of_states) + + self.num_qudits = num_qudits + self.num_of_legs = num_qudits + 1 + self.radixes = tuple( + radixes if len(radixes) > 0 else [2] * num_qudits, + ) + self.single_leg_radix = len(list_of_states) + self.tensor = jnp.array(list_of_states).reshape( + self.single_leg_radix, *self.radixes, + ) + elif tensor is not None and single_leg_radix is not None: + self.tensor = tensor + self.single_leg_radix = single_leg_radix + self.num_qudits = num_qudits + self.num_of_legs = num_qudits + 1 + self.radixes = tuple( + radixes if len(radixes) > 0 else [2] * num_qudits, + ) + else: + raise RuntimeError("can't create the instance") + + def copy(self: T) -> T: + return self.__class__( + tensor=self.tensor.copy(), + num_qudits=self.num_qudits, + radixes=self.radixes, + single_leg_radix=self.single_leg_radix, + ) + + @staticmethod + def calc_env( + left: LHSTensor, right: RHSTensor, + indexes_to_leave_open: Sequence[int], + ) -> Array: + + # verify correct shape + assert left.radixes == right.radixes + assert left.single_leg_radix == right.single_leg_radix, \ + f'{left.single_leg_radix} != {right.single_leg_radix}' + + left_contraction_indexs = list(range(left.num_qudits + 1)) + right_contraction_indexs = list(range(left.num_qudits + 1)) + + size_of_open = len(indexes_to_leave_open) + + for leg_num, i in enumerate(indexes_to_leave_open): + left_contraction_indexs[i + 1] = size_of_open + \ + leg_num + left.num_of_legs + right_contraction_indexs[i + 1] = leg_num + left.num_of_legs + + env_tensor = jnp.einsum( + left.tensor, left_contraction_indexs, + right.tensor, 
right_contraction_indexs, + ) + env_mat = env_tensor.reshape((2**size_of_open, -1)) + + return env_mat + + +class RHSTensor(SingleLegSideTensor): + + def apply_left( + self, + utry: UnitaryMatrixJax, + location: Sequence[CircuitLocation], + ) -> None: + """ + Apply the specified unitary on the left of this rhs tensor + + .. + .------. .-----. + 1 -| |---| | + 2 -| gate |---| . + '------' . . + . |- 0 + . . + n ------------| | + '-----' + + """ + + utry_tensor = utry.get_tensor_format() + utry_size = len(utry.radixes) + + utry_tensor_indexs = [ + i + self.num_of_legs for i in range(utry_size) + ] + [1 + l for l in location] + + rhs_tensor_indexes = list(range(self.num_of_legs)) + output_tensor_index = list(range(self.num_of_legs)) + + # matching the leg indexs of the utry + for i, loc in enumerate(location): + rhs_tensor_indexes[1 + loc] = i + self.num_of_legs + + self.tensor = jnp.einsum( + utry_tensor, utry_tensor_indexs, + self.tensor, rhs_tensor_indexes, output_tensor_index, + ) + + +class LHSTensor(SingleLegSideTensor): + + def apply_right( + self, + utry: UnitaryMatrixJax, + location: Sequence[CircuitLocation], + ) -> None: + """ + Apply the specified unitary on the right of this lhs tensor. + + The result looks like this .-----. .------. | |---| + |- 1 | |---| utry |- 2 . . '------' 0-| + . . . . . 
| |------------ n '-----' + """ + + utry_tensor = utry.get_tensor_format() + utry_size = len(utry.radixes) + + utry_tensor_indexs = [1 + l for l in location] + \ + [i + self.num_of_legs for i in range(utry_size)] + lhs_tensor_indexes = list(range(self.num_of_legs)) + output_tensor_index = list(range(self.num_of_legs)) + + # matching the leg indexs of the utry + for i, loc in enumerate(location): + lhs_tensor_indexes[1 + loc] = i + self.num_of_legs + + self.tensor = jnp.einsum( + utry_tensor, utry_tensor_indexs, + self.tensor, lhs_tensor_indexes, output_tensor_index, + ) diff --git a/qfactorjax/unitarybuilderjax.py b/qfactorjax/unitarybuilderjax.py index 27ab4c2..1c60835 100644 --- a/qfactorjax/unitarybuilderjax.py +++ b/qfactorjax/unitarybuilderjax.py @@ -121,7 +121,7 @@ def get_unitary(self, params: RealVector = []) -> UnitaryMatrixJax: def apply_right( self, - utry: UnitaryMatrix, + utry: UnitaryMatrixJax, location: CircuitLocationLike, inverse: bool = False, check_arguments: bool = True, @@ -140,7 +140,7 @@ def apply_right( '-----' Args: - utry (UnitaryMatrix): The unitary to apply. + utry (UnitaryMatrixJax): The unitary to apply. location (CircuitLocationLike): The qudits to apply the unitary on. 
@@ -208,7 +208,7 @@ def apply_right( def apply_left( self, - utry: UnitaryMatrix, + utry: UnitaryMatrixJax, location: CircuitLocationLike, inverse: bool = False, check_arguments: bool = True, diff --git a/qfactorjax/unitarymatrixjax.py b/qfactorjax/unitarymatrixjax.py index 50ea402..6afec36 100644 --- a/qfactorjax/unitarymatrixjax.py +++ b/qfactorjax/unitarymatrixjax.py @@ -61,7 +61,7 @@ def __init__( and type(input) is not jax.core.ShapedArray and not _from_tree ): - dim = np.prod(radixes) + dim = np.prod(self._radixes) self._utry = jnp.array(input, dtype=jnp.complex128).reshape( (dim, dim), ) # make sure its a square matrix diff --git a/setup.cfg b/setup.cfg index ac307d3..5477a32 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = bqskit-qfactor-jax -version = 0.0.1 +version = 1.0.0 description = QFactor GPU implementation in BQSKit using JAX long_description = file: README.md long_description_content_type = text/markdown diff --git a/tests/test_examples.py b/tests/test_examples.py index 55be638..dc0d483 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -7,6 +7,8 @@ from examples.gate_deletion_syth import run_gate_del_flow_example from examples.toffoli_instantiation import run_toffoli_instantiation +from examples.toffoli_instantiation_using_sampling import\ + run_toffoli_instantiation as run_toffoli_instantiation_using_sampling def test_toffoli_instantiation() -> None: @@ -14,6 +16,11 @@ def test_toffoli_instantiation() -> None: assert distance <= 1e-10 +def test_toffoli_instantiation_using_sampling() -> None: + distance = run_toffoli_instantiation_using_sampling() + assert distance <= 1e-10 + + def test_gate_del_synth() -> None: if 'AMOUNT_OF_WORKERS' in os.environ: diff --git a/tests/test_single_leg_tensor.py b/tests/test_single_leg_tensor.py new file mode 100644 index 0000000..80c4f7d --- /dev/null +++ b/tests/test_single_leg_tensor.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from random import randint 
+from random import sample + +import jax +import jax.numpy as jnp +import pytest +from bqskit.ir.gates import CXGate +from bqskit.ir.gates import HGate +from scipy.stats import unitary_group + +from qfactorjax.singlelegedtensor import LHSTensor +from qfactorjax.singlelegedtensor import RHSTensor +from qfactorjax.singlelegedtensor import SingleLegSideTensor +from qfactorjax.unitarymatrixjax import UnitaryMatrixJax + + +jax.config.update('jax_enable_x64', True) + + +@pytest.mark.parametrize( + 'num_qubits, N', + [(randint(2, 7), randint(3, 10)) for _ in range(6)], +) +def test_full_contraction_with_complex_conj(num_qubits: int, N: int) -> None: + random_kets = [unitary_group.rvs(2**num_qubits)[:, :1] for _ in range(N)] + random_bras = [ket.T.conj() for ket in random_kets] + + rhs = RHSTensor(list_of_states=random_bras, num_qudits=num_qubits) + lhs = LHSTensor(list_of_states=random_kets, num_qudits=num_qubits) + + res = SingleLegSideTensor.calc_env(lhs, rhs, []).reshape(1)[0] + + assert jnp.isclose(res, N) + + +@pytest.mark.parametrize( + 'num_qubits, N', + [(randint(2, 7), randint(3, 10)) for _ in range(3)], +) +def test_apply_left_H(num_qubits: int, N: int) -> None: + random_kets = [unitary_group.rvs(2**num_qubits)[:, :1] for _ in range(N)] + random_bras = [ket.T.conj() for ket in random_kets] + + rhs = RHSTensor(list_of_states=random_bras, num_qudits=num_qubits) + orig = rhs.copy() + + H_mat = UnitaryMatrixJax(HGate().get_unitary()) + location = sorted(sample(range(num_qubits), 1)) + rhs.apply_left(H_mat, location) + assert not all(jnp.isclose(rhs.tensor, orig.tensor).reshape(-1)) + rhs.apply_left(H_mat, location) + assert all(jnp.isclose(rhs.tensor, orig.tensor).reshape(-1)) + + +@pytest.mark.parametrize( + 'num_qubits, N', + [(randint(2, 7), randint(3, 10)) for _ in range(3)], +) +def test_apply_right_H(num_qubits: int, N: int) -> None: + random_kets = [unitary_group.rvs(2**num_qubits)[:, :1] for _ in range(N)] + random_bras = [ket.T.conj() for ket in 
random_kets] + + rhs = LHSTensor(list_of_states=random_bras, num_qudits=num_qubits) + orig = rhs.copy() + + H_mat = UnitaryMatrixJax(HGate().get_unitary()) + location = sorted(sample(range(num_qubits), 1)) + rhs.apply_right(H_mat, location) + assert not all(jnp.isclose(rhs.tensor, orig.tensor).reshape(-1)) + rhs.apply_right(H_mat, location) + assert all(jnp.isclose(rhs.tensor, orig.tensor).reshape(-1)) + + +@pytest.mark.parametrize( + 'num_qubits, N', + [(randint(2, 7), randint(3, 10)) for _ in range(3)], +) +def test_apply_left_CX(num_qubits: int, N: int) -> None: + random_kets = [unitary_group.rvs(2**num_qubits)[:, :1] for _ in range(N)] + random_bras = [ket.T.conj() for ket in random_kets] + + rhs = RHSTensor(list_of_states=random_bras, num_qudits=num_qubits) + orig = rhs.copy() + + H_mat = UnitaryMatrixJax(CXGate().get_unitary()) + location = sorted(sample(range(num_qubits), 2)) + rhs.apply_left(H_mat, location) + assert not all(jnp.isclose(rhs.tensor, orig.tensor).reshape(-1)) + rhs.apply_left(H_mat, location) + assert all(jnp.isclose(rhs.tensor, orig.tensor).reshape(-1)) + + +@pytest.mark.parametrize( + 'num_qubits, N', + [(randint(2, 7), randint(3, 10)) for _ in range(3)], +) +def test_apply_right_CX(num_qubits: int, N: int) -> None: + random_kets = [unitary_group.rvs(2**num_qubits)[:, :1] for _ in range(N)] + random_bras = [ket.T.conj() for ket in random_kets] + + rhs = LHSTensor(list_of_states=random_bras, num_qudits=num_qubits) + orig = rhs.copy() + + H_mat = UnitaryMatrixJax(CXGate().get_unitary()) + location = sorted(sample(range(num_qubits), 2)) + rhs.apply_right(H_mat, location) + assert not all(jnp.isclose(rhs.tensor, orig.tensor).reshape(-1)) + rhs.apply_right(H_mat, location) + assert all(jnp.isclose(rhs.tensor, orig.tensor).reshape(-1))