Skip to content

Commit

Permalink
Merge pull request #5 from BQSKit/state-sample-improvment
Browse files Browse the repository at this point in the history
Adding QFactor-Sample
  • Loading branch information
alonkukl authored Jul 2, 2024
2 parents 1d8b3b7 + d460bbf commit 726b006
Show file tree
Hide file tree
Showing 18 changed files with 1,456 additions and 16 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
os: [ubuntu-latest]

runs-on: ${{ matrix.os }}
Expand All @@ -24,7 +24,7 @@ jobs:
- name: Upgrade test dependencies
run: python -m pip install psutil pytest 'hypothesis[zoneinfo]' qiskit
- name: Install JAX
run: pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
run: pip install jax
- name: Install QFactor JAX
run: pip install .
- name: Run tests
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2023,
Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2024,
U.S. Federal Government and the Government of Israel. All rights reserved.

Redistribution and use in source and binary forms, with or without
Expand Down
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# QFactor implementation on GPUs using JAX
`bqskit-qfactor-jax` is a Python package that implements circuit instantiation with [QFactor](https://arxiv.org/abs/2306.08152) on GPUs to accelerate [BQSKit](https://github.com/bqskit/bqskit). It uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as an abstraction layer of the GPUs, seamlessly utilizing JIT compilation and GPU parallelism.
# QFactor and QFactor-Sample implementations on GPUs using JAX
`bqskit-qfactor-jax` is a Python package that implements circuit instantiation using the [QFactor](https://ieeexplore.ieee.org/abstract/document/10313638) and [QFactor-Sample](https://arxiv.org/abs/2405.12866) algorithms on GPUs to accelerate [BQSKit](https://github.com/bqskit/bqskit). It uses [JAX](https://jax.readthedocs.io/en/latest/index.html) as an abstraction layer of the GPUs, seamlessly utilizing JIT compilation and GPU parallelism.

## Installation
`bqskit-qfactor-jax` is available for Python 3.8+ on Linux.
Expand All @@ -21,7 +21,7 @@ pip install bqskit-qfactor-jax
# Running bqskit-qfactor-jax
Please set the environment variable XLA_PYTHON_CLIENT_PREALLOCATE=False when using this package. Also, if you encounter OOM issues consider setting XLA_PYTHON_CLIENT_ALLOCATOR=platform.

Please take a look at the [examples](https://github.com/BQSKit/bqskit-qfactor-jax/tree/main/examples) to see some basic usage.
Please look at the [examples](https://github.com/BQSKit/bqskit-qfactor-jax/tree/main/examples) for basic usage, especially at performance comparison between QFactor and QFactor-Sample.

When using several workers on the same GPU, we recommend using [Nvidia's MPS](https://docs.nvidia.com/deploy/mps/index.html). You may initiate it using the command line
```sh
Expand All @@ -34,7 +34,11 @@ echo quit | nvidia-cuda-mps-control
```

# References
Kukliansky, Alon, et al. "QFactor:A Domain-Specific Optimizer for Quantum Circuit Instantiation." arXiv preprint [arXiv:2306.08152](https://arxiv.org/abs/2306.08152) (2023).
If you are using QFactor-JAX please cite:\
Kukliansky, Alon, et al. "QFactor: A Domain-Specific Optimizer for Quantum Circuit Instantiation." 2023 IEEE International Conference on Quantum Computing and Engineering (QCE). Vol. 1. IEEE, 2023. [Link](https://ieeexplore.ieee.org/abstract/document/10313638).

If you are using QFactor-Sample please cite:\
Kukliansky, Alon, et al. "Leveraging Quantum Machine Learning Generalization to Significantly Speed-up Quantum Compilation" arXiv preprint [arXiv:2405.12866](https://arxiv.org/abs/2405.12866) (2024).

## License
The software in this repository is licensed under a **BSD free software
Expand All @@ -45,5 +49,5 @@ for more information.

## Copyright

Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2023,
Quantum Fast Circuit Optimizer (QFactor) JAX implementation Copyright (c) 2024,
U.S. Federal Government and the Government of Israel. All rights reserved.
87 changes: 87 additions & 0 deletions examples/adder63_10q_block_28.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
OPENQASM 2.0;
include "qelib1.inc";
qreg q[10];
cx q[7], q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[9];
cx q[7], q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[9];
cx q[7], q[9];
cx q[5], q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[5];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
cx q[5], q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[5];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
cx q[5], q[7];
cx q[4], q[5];
cx q[7], q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[5];
u3(0.0, 0.0, 0.7853981633974483) q[7];
cx q[4], q[5];
cx q[6], q[7];
u3(0.0, 0.0, 5.497787143782138) q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[5];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[6];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
cx q[4], q[5];
cx q[6], q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[6];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
cx q[2], q[4];
cx q[6], q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[2];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, 6.283185307179586) q[6];
cx q[7], q[8];
u3(0.0, 0.0, 0.7853981633974483) q[2];
cx q[3], q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[8];
cx q[0], q[2];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
cx q[7], q[8];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[0];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[2];
cx q[3], q[4];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[8];
cx q[0], q[2];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
cx q[7], q[8];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[0];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[2];
cx q[3], q[4];
cx q[7], q[9];
cx q[0], q[2];
cx q[1], q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, 0.0, -3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[1];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
cx q[6], q[7];
u3(0.0, 0.0, 7.0685834705770345) q[9];
cx q[1], q[3];
u3(1.5707963267948966, 2.356194490192345, -3.141592653589793) q[6];
u3(1.5707963267948966, -2.356194490192345, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[1];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
cx q[7], q[9];
cx q[1], q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[1];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[3];
cx q[7], q[9];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[9];
cx q[7], q[9];
cx q[6], q[7];
u3(0.0, 0.0, 5.497787143782138) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[7];
141 changes: 141 additions & 0 deletions examples/compare_qfactor_sample_to_qfactor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from __future__ import annotations

import argparse
import time

from bqskit import enable_logging
from bqskit.compiler import CompilationTask
from bqskit.compiler import Compiler
from bqskit.ir.circuit import Circuit
from bqskit.passes import ToVariablePass

from qfactorjax.qfactor import QFactorJax
from qfactorjax.qfactor_sample_jax import QFactorSampleJax

enable_logging(verbose=True)


parser = argparse.ArgumentParser(
description='Comparing the re-instantiation run time of QFactor-JAX and '
'QFactor-Sample-JAX. Running on adder63_10q_block_28.qasm the '
'difference is X4, and for heisenberg64_10q_block_104.qasm '
'the difference is X10 and QFactor-JAX doesn\'t find a solution. '
'For vqe12_10q_block145.qasm the difference is X34.',
)

parser.add_argument('--input_qasm', type=str, required=True)
parser.add_argument('--multistarts', type=int, default=32)
parser.add_argument('--max_iters', type=int, default=6000)
parser.add_argument('--dist_tol', type=float, default=1e-8)
parser.add_argument('--num_params_coef', type=int, default=1)
parser.add_argument('--exact_amount_of_sample_states', type=int)
parser.add_argument('--overtrain_relative_threshold', type=float, default=0.1)


params = parser.parse_args()


print(params)

file_name = params.input_qasm
dist_tol_requested = params.dist_tol
num_mutlistarts = params.multistarts
max_iters = params.max_iters

num_params_coef = params.num_params_coef

exact_amount_of_sample_states = params.exact_amount_of_sample_states
overtrain_relative_threshold = params.overtrain_relative_threshold


instantiate_options = {
'multistarts': num_mutlistarts,
}


qfactor_gpu_instantiator = QFactorJax(

dist_tol=dist_tol_requested, # Stopping criteria for distance

max_iters=100000, # Maximum number of iterations
min_iters=1, # Minimum number of iterations

# One step plateau detection -
# diff_tol_a + diff_tol_r ∗ |c(i)| <= |c(i)|-|c(i-1)|
diff_tol_a=0.0, # Stopping criteria for distance change
diff_tol_r=1e-10, # Relative criteria for distance change

# Long plateau detection -
# diff_tol_step_r*|c(i-diff_tol_step)| <= |c(i)|-|c(i-diff_tol_step)|
diff_tol_step_r=0.1, # The relative improvement expected
diff_tol_step=200, # The interval in which to check the improvement

# Regularization parameter - [0.0 - 1.0]
# Increase to overcome local minima at the price of longer compute
beta=0.0,
)


qfactor_sample_gpu_instantiator = QFactorSampleJax(

dist_tol=dist_tol_requested, # Stopping criteria for distance

max_iters=max_iters, # Maximum number of iterations
min_iters=6, # Minimum number of iterations

# Regularization parameter - [0.0 - 1.0]
# Increase to overcome local minima at the price of longer compute
beta=0.0,

amount_of_validation_states=2,
# indicates the ratio between the sum of parameters in the circuits to the
# sample size.
diff_tol_r=1e-4,
num_params_coef=num_params_coef,
overtrain_relative_threshold=overtrain_relative_threshold,
exact_amount_of_states_to_train_on=exact_amount_of_sample_states,
)


print(
f'Will use {file_name} {dist_tol_requested = } {num_mutlistarts = }'
f' {num_params_coef = }',
)

orig_10q_block_cir = Circuit.from_file(f'{file_name}')

with Compiler(num_workers=1) as compiler:
task = CompilationTask(orig_10q_block_cir, [ToVariablePass()])
task_id = compiler.submit(task)
orig_10q_block_cir_vu = compiler.result(task_id)


tic = time.perf_counter()
target = orig_10q_block_cir_vu.get_unitary()
time_to_simulate_circ = time.perf_counter() - tic
print(f'Time to simulate was {time_to_simulate_circ}')

tic = time.perf_counter()
orig_10q_block_cir_vu.instantiate(
target, multistarts=num_mutlistarts, method=qfactor_sample_gpu_instantiator,
)
sample_inst_time = time.perf_counter() - tic
inst_sample_dist_from_target = orig_10q_block_cir_vu.get_unitary(
).get_distance_from(target, 1)

print(
f'QFactor-Sample-JAX {sample_inst_time = } '
f'{inst_sample_dist_from_target = }'
f' {num_params_coef = }',
)

tic = time.perf_counter()
orig_10q_block_cir_vu.instantiate(
target, multistarts=num_mutlistarts, method=qfactor_gpu_instantiator,
)
full_inst_time = time.perf_counter() - tic
inst_dist_from_target = orig_10q_block_cir_vu.get_unitary().get_distance_from(
target, 1,
)

print(f'QFactor-JAX {full_inst_time = } {inst_dist_from_target = }')
8 changes: 7 additions & 1 deletion examples/gate_deletion_syth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from timeit import default_timer as timer

from bqskit import Circuit
from bqskit import enable_logging
from bqskit.compiler import Compiler
from bqskit.passes import ForEachBlockPass
from bqskit.passes import QuickPartitioner
Expand All @@ -18,6 +19,9 @@
from qfactorjax.qfactor import QFactorJax


enable_logging()


def run_gate_del_flow_example(
amount_of_workers: int = 10,
) -> tuple[Circuit, Circuit, float]:
Expand Down Expand Up @@ -102,7 +106,9 @@ def run_gate_del_flow_example(

if __name__ == '__main__':

in_circuit, out_circuit, run_time = run_gate_del_flow_example()
in_circuit, out_circuit, run_time = run_gate_del_flow_example(
amount_of_workers=1,
)

print(
f'Partitioning + Synthesis took {run_time}'
Expand Down
57 changes: 57 additions & 0 deletions examples/heisenberg64_10q_block_104.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
OPENQASM 2.0;
include "qelib1.inc";
qreg q[10];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[1];
cx q[0], q[1];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[2];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[3];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[4];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[5];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[6];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[7];
u3(1.5707963267948966, 0.0, 3.141592653589793) q[8];
u3(1.5707963267948966, -3.141592653589793, -3.141592653589793) q[9];
u3(0.0, 0.0, 0.02) q[1];
cx q[0], q[1];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[0];
u3(0.0, 1.406583, -1.406583) q[1];
cx q[1], q[2];
u3(0.0, 0.0, 0.02) q[2];
cx q[1], q[2];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[1];
u3(0.0, 1.406583, -1.406583) q[2];
cx q[2], q[3];
u3(0.0, 0.0, 0.02) q[3];
cx q[2], q[3];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[2];
u3(0.0, 1.406583, -1.406583) q[3];
cx q[3], q[4];
u3(0.0, 0.0, 0.02) q[4];
cx q[3], q[4];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[3];
u3(0.0, 1.406583, -1.406583) q[4];
cx q[4], q[5];
u3(0.0, 0.0, 0.02) q[5];
cx q[4], q[5];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[4];
u3(0.0, 1.406583, -1.406583) q[5];
cx q[5], q[6];
u3(0.0, 0.0, 0.02) q[6];
cx q[5], q[6];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[5];
u3(0.0, 1.406583, -1.406583) q[6];
cx q[6], q[7];
u3(0.0, 0.0, 0.02) q[7];
cx q[6], q[7];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[6];
u3(0.0, 1.406583, -1.406583) q[7];
cx q[7], q[8];
u3(0.0, 0.0, 0.02) q[8];
cx q[7], q[8];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[7];
u3(0.0, 1.406583, -1.406583) q[8];
cx q[8], q[9];
u3(0.0, 0.0, 0.02) q[9];
cx q[8], q[9];
u3(1.5707963267948966, 0.0, 1.5707963267948966) q[8];
u3(0.0, 1.406583, -1.406583) q[9];
Loading

0 comments on commit 726b006

Please sign in to comment.