Skip to content

Commit cd0cff6

Browse files
[major] reformatting functionals
1 parent e436a79 commit cd0cff6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+4889
-4391
lines changed

torchquantum/functional/__init__.py

Lines changed: 190 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,197 @@
2222
SOFTWARE.
2323
"""
2424

25-
from .functionals import *
25+
from .gate_wrapper import gate_wrapper, apply_unitary_einsum, apply_unitary_bmm
26+
from .hadamard import hadamard, shadamard, _hadamard_mat_dict, h, ch, sh, chadamard
27+
from .rx import rx, rxx, crx, xx, _rx_mat_dict, rx_matrix, rxx_matrix, crx_matrix
28+
from .ry import ry, ryy, cry, yy, _ry_mat_dict, ry_matrix, ryy_matrix, cry_matrix
29+
from .rz import (
30+
rz,
31+
rzz,
32+
crz,
33+
zz,
34+
zx,
35+
multirz,
36+
rzx,
37+
_rz_mat_dict,
38+
rz_matrix,
39+
rzz_matrix,
40+
crz_matrix,
41+
multirz_matrix,
42+
rzx_matrix,
43+
)
44+
from .phase_shift import phaseshift_matrix, phaseshift, p, _phaseshift_mat_dict
45+
from .rot import rot, crot, rot_matrix, crot_matrix, _rot_mat_dict
46+
from .reset import reset
47+
from .xx_min_yy import xxminyy, xxminyy_matrix, _xxminyy_mat_dict
48+
from .xx_plus_yy import xxplusyy, xxplusyy_matrix, _xxplusyy_mat_dict
49+
from .u1 import u1, cu1, u1_matrix, cu1_matrix, _u1_mat_dict, cp, cr, cphase
50+
from .u2 import u2, cu2, u2_matrix, cu2_matrix, _u2_mat_dict
51+
from .u3 import u, u3, cu3, cu, cu_matrix, u3_matrix, cu3_matrix, _u3_mat_dict
52+
from .qubit_unitary import (
53+
qubitunitary,
54+
qubitunitaryfast,
55+
qubitunitarystrict,
56+
qubitunitary_matrix,
57+
qubitunitaryfast_matrix,
58+
qubitunitarystrict_matrix,
59+
_qubitunitary_mat_dict,
60+
)
61+
from .single_excitation import (
62+
singleexcitation,
63+
singleexcitation_matrix,
64+
_singleexcitation_mat_dict,
65+
)
66+
from .paulix import (
67+
_x_mat_dict,
68+
multicnot_matrix,
69+
multixcnot_matrix,
70+
paulix,
71+
cnot,
72+
multicnot,
73+
multixcnot,
74+
x,
75+
c3x,
76+
c4x,
77+
dcx,
78+
toffoli,
79+
ccnot,
80+
ccx,
81+
cx,
82+
rccx,
83+
rc3x,
84+
)
85+
from .pauliy import _y_mat_dict, pauliy, cy, y
86+
from .pauliz import _z_mat_dict, pauliz, cz, ccz, z
87+
from .qft import _qft_mat_dict, qft, qft_matrix
88+
from .r import _r_mat_dict, r, r_matrix
89+
from .global_phase import _globalphase_mat_dict, globalphase, globalphase_matrix
90+
from .sx import _sx_mat_dict, sx, c3sx, sxdg, csx
91+
from .i import _i_mat_dict, i
92+
from .s import _s_mat_dict, s, sdg, cs, csdg
93+
from .t import _t_mat_dict, t, tdg
94+
from .swap import _swap_mat_dict, swap, sswap, iswap, cswap
95+
from .ecr import _ecr_mat_dict, ecr, echoedcrossresonance
2696

27-
from .func_mat_exp import *
28-
from .func_controlled_unitary import *
97+
mat_dict = {
98+
**_hadamard_mat_dict,
99+
**_rx_mat_dict,
100+
**_ry_mat_dict,
101+
**_rz_mat_dict,
102+
**_phaseshift_mat_dict,
103+
**_rot_mat_dict,
104+
**_xxminyy_mat_dict,
105+
**_xxplusyy_mat_dict,
106+
**_u1_mat_dict,
107+
**_u2_mat_dict,
108+
**_u3_mat_dict,
109+
**_qubitunitary_mat_dict,
110+
**_x_mat_dict,
111+
**_y_mat_dict,
112+
**_z_mat_dict,
113+
**_singleexcitation_mat_dict,
114+
**_qft_mat_dict,
115+
**_r_mat_dict,
116+
**_globalphase_mat_dict,
117+
**_sx_mat_dict,
118+
**_i_mat_dict,
119+
**_s_mat_dict,
120+
**_t_mat_dict,
121+
**_swap_mat_dict,
122+
**_ecr_mat_dict,
123+
}
124+
125+
func_name_dict = {
126+
"hadamard": hadamard,
127+
"h": h,
128+
"sh": shadamard,
129+
"paulix": paulix,
130+
"pauliy": pauliy,
131+
"pauliz": pauliz,
132+
"i": i,
133+
"s": s,
134+
"t": t,
135+
"sx": sx,
136+
"cnot": cnot,
137+
"cz": cz,
138+
"cy": cy,
139+
"rx": rx,
140+
"ry": ry,
141+
"rz": rz,
142+
"rxx": rxx,
143+
"xx": xx,
144+
"ryy": ryy,
145+
"yy": yy,
146+
"rzz": rzz,
147+
"zz": zz,
148+
"rzx": rzx,
149+
"zx": zx,
150+
"swap": swap,
151+
"sswap": sswap,
152+
"cswap": cswap,
153+
"toffoli": toffoli,
154+
"phaseshift": phaseshift,
155+
"p": p,
156+
"cp": cp,
157+
"rot": rot,
158+
"multirz": multirz,
159+
"crx": crx,
160+
"cry": cry,
161+
"crz": crz,
162+
"crot": crot,
163+
"u1": u1,
164+
"u2": u2,
165+
"u3": u3,
166+
"u": u,
167+
"cu1": cu1,
168+
"cphase": cphase,
169+
"cr": cr,
170+
"cu2": cu2,
171+
"cu3": cu3,
172+
"cu": cu,
173+
"qubitunitary": qubitunitary,
174+
"qubitunitaryfast": qubitunitaryfast,
175+
"qubitunitarystrict": qubitunitarystrict,
176+
"multicnot": multicnot,
177+
"multixcnot": multixcnot,
178+
"x": x,
179+
"y": y,
180+
"z": z,
181+
"cx": cx,
182+
"ccnot": ccnot,
183+
"ccx": ccx,
184+
"reset": reset,
185+
"singleexcitation": singleexcitation,
186+
"ecr": ecr,
187+
"echoedcrossresonance": echoedcrossresonance,
188+
"qft": qft,
189+
"sdg": sdg,
190+
"tdg": tdg,
191+
"sxdg": sxdg,
192+
"ch": ch,
193+
"ccz": ccz,
194+
"iswap": iswap,
195+
"cs": cs,
196+
"csdg": csdg,
197+
"csx": csx,
198+
"chadamard": chadamard,
199+
"ccz": ccz,
200+
"dcx": dcx,
201+
"xxminyy": xxminyy,
202+
"xxplusyy": xxplusyy,
203+
"c3x": c3x,
204+
"r": r,
205+
"globalphase": globalphase,
206+
"c3sx": c3sx,
207+
"rccx": rccx,
208+
"rc3x": rc3x,
209+
"c4x": c4x,
210+
}
211+
212+
from .func_mat_exp import matrix_exp
213+
from .func_controlled_unitary import controlled_unitary
29214

30215
func_name_dict_collect = {
31-
'matrix_exp': matrix_exp,
32-
'controlled_unitary': controlled_unitary,
216+
"matrix_exp": matrix_exp,
217+
"controlled_unitary": controlled_unitary,
33218
}

torchquantum/functional/ecr.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import functools
2+
import torch
3+
import numpy as np
4+
5+
from typing import Callable, Union, Optional, List, Dict, TYPE_CHECKING
6+
from ..macro import C_DTYPE, F_DTYPE, ABC, ABC_ARRAY, INV_SQRT2
7+
from ..util.utils import pauli_eigs, diag
8+
from torchpack.utils.logging import logger
9+
from torchquantum.util import normalize_statevector
10+
11+
from .gate_wrapper import gate_wrapper, apply_unitary_einsum, apply_unitary_bmm
12+
13+
if TYPE_CHECKING:
14+
from torchquantum.device import QuantumDevice
15+
else:
16+
QuantumDevice = None
17+
18+
_ecr_mat_dict = {
19+
"ecr": INV_SQRT2
20+
* torch.tensor(
21+
[[0, 0, 1, 1j], [0, 0, 1j, 1], [1, -1j, 0, 0], [-1j, 1, 0, 0]], dtype=C_DTYPE
22+
),
23+
}
24+
25+
26+
def ecr(
27+
q_device,
28+
wires,
29+
params=None,
30+
n_wires=None,
31+
static=False,
32+
parent_graph=None,
33+
inverse=False,
34+
comp_method="bmm",
35+
):
36+
"""Perform the echoed cross-resonance gate.
37+
https://qiskit.org/documentation/stubs/qiskit.circuit.library.ECRGate.html
38+
39+
Args:
40+
q_device (tq.QuantumDevice): The QuantumDevice.
41+
wires (Union[List[int], int]): Which qubit(s) to apply the gate.
42+
params (torch.Tensor, optional): Parameters (if any) of the gate.
43+
Default to None.
44+
n_wires (int, optional): Number of qubits the gate is applied to.
45+
Default to None.
46+
static (bool, optional): Whether use static mode computation.
47+
Default to False.
48+
parent_graph (tq.QuantumGraph, optional): Parent QuantumGraph of
49+
current operation. Default to None.
50+
inverse (bool, optional): Whether inverse the gate. Default to False.
51+
comp_method (bool, optional): Use 'bmm' or 'einsum' method to perform
52+
matrix vector multiplication. Default to 'bmm'.
53+
54+
Returns:
55+
None.
56+
57+
"""
58+
name = "ecr"
59+
mat = _ecr_mat_dict[name]
60+
gate_wrapper(
61+
name=name,
62+
mat=mat,
63+
method=comp_method,
64+
q_device=q_device,
65+
wires=wires,
66+
params=params,
67+
n_wires=n_wires,
68+
static=static,
69+
parent_graph=parent_graph,
70+
inverse=inverse,
71+
)
72+
73+
74+
echoedcrossresonance = ecr

torchquantum/functional/func_controlled_unitary.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424

2525
import numpy as np
2626
import torch
27-
from torchquantum.functional.functionals import gate_wrapper
27+
from torchquantum.functional.gate_wrapper import gate_wrapper
2828
from torchquantum.macro import *
2929

30+
3031
def controlled_unitary(
3132
qdev,
3233
c_wires,
@@ -41,7 +42,7 @@ def controlled_unitary(
4142
t_wires: can be a list of list of wires, multiple sets
4243
[[1,2], [3,4]]
4344
params: the parameters of the unitary
44-
45+
4546
Returns:
4647
None.
4748
@@ -113,19 +114,19 @@ def controlled_unitary(
113114
unitary[-d_controlled_u:, -d_controlled_u:] = controlled_u
114115

115116
# return cls(
116-
# has_params=True,
117-
# trainable=trainable,
118-
# init_params=unitary,
119-
# n_wires=n_wires,
120-
# wires=wires,
117+
# has_params=True,
118+
# trainable=trainable,
119+
# init_params=unitary,
120+
# n_wires=n_wires,
121+
# wires=wires,
121122
# )
122123

123-
name = 'qubitunitaryfast'
124+
name = "qubitunitaryfast"
124125
unitary = unitary.to(qdev.device)
125126
gate_wrapper(
126127
name=name,
127128
mat=unitary,
128-
method='bmm',
129+
method="bmm",
129130
q_device=qdev,
130131
wires=wires,
131132
n_wires=n_wires,

torchquantum/functional/func_mat_exp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@
2323
"""
2424

2525
import torch
26-
from .functionals import gate_wrapper
26+
from .gate_wrapper import gate_wrapper
2727
from typing import Union
2828
import numpy as np
2929
import torchquantum.functional as tqf
3030

3131
__all__ = ["matrix_exp"]
3232

33+
3334
def matrix_exp(
3435
qdev,
3536
wires,
@@ -62,7 +63,7 @@ def matrix_exp(
6263

6364
mat = torch.matrix_exp(params)
6465

65-
name = 'qubitunitaryfast'
66+
name = "qubitunitaryfast"
6667

6768
tqf.qubitunitaryfast(
6869
q_device=qdev,

0 commit comments

Comments
 (0)