Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pyzx/circuit/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from ..utils import EdgeType, VertexType, FractionLike
from ..graph.base import BaseGraph, VT, ET
from ..symbolic import new_var
from ..symbolic import new_var, Poly

# We need this type variable so that the subclasses of Gate return the correct type for functions like copy()
Tvar = TypeVar('Tvar', bound='Gate')
Expand Down Expand Up @@ -233,7 +233,11 @@ def copy(self: Tvar) -> Tvar:
def to_adjoint(self: Tvar) -> Tvar:
g = self.copy()
if hasattr(g, "phase"):
g.phase = -g.phase
phase = g.phase
if isinstance(phase, Poly):
g.phase = -phase.conjugate()
else:
g.phase = -phase
if hasattr(g, "adjoint"):
g.adjoint = not g.adjoint
return g
Expand Down
8 changes: 4 additions & 4 deletions pyzx/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,8 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> BaseGraph[VT,E
g.merge_vdata = self.merge_vdata # type: ignore
for name in self.var_registry.vars():
g.var_registry.set_type(name, self.var_registry.get_type(name))
mult:int = 1
if adjoint:
mult = -1
def adjoint_phase(phase: FractionLike) -> FractionLike:
return -phase.conjugate()

#g.add_vertices(self.num_vertices())
ty = self.types()
Expand All @@ -496,7 +495,8 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> BaseGraph[VT,E
maxr = self.depth()
vtab = dict()
for v in self.vertices():
i = g.add_vertex(ty[v],phase=mult*ph[v])
phase = ph[v] if not adjoint else adjoint_phase(ph[v])
i = g.add_vertex(ty[v], phase=phase)
if v in qs: g.set_qubit(i,qs[v])
if v in rs:
if adjoint: g.set_row(i, maxr-rs[v])
Expand Down
9 changes: 6 additions & 3 deletions pyzx/graph/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,20 @@ def copy(self, conjugate: bool = False) -> 'Scalar':
Returns:
A copy of the Scalar
"""
def conjugate_phase(phase: FractionLike) -> FractionLike:
return -phase.conjugate()

s = Scalar()
s.power2 = self.power2
s.phase = self.phase if not conjugate else -self.phase
s.phasenodes = copy.copy(self.phasenodes) if not conjugate else [-p for p in self.phasenodes]
s.phase = self.phase if not conjugate else conjugate_phase(self.phase)
s.phasenodes = copy.copy(self.phasenodes) if not conjugate else [conjugate_phase(p) for p in self.phasenodes]
s.floatfactor = self.floatfactor if not conjugate else self.floatfactor.conjugate()
s.is_unknown = self.is_unknown
s.is_zero = self.is_zero
if not conjugate:
s.sum_of_phases = copy.deepcopy(self.sum_of_phases)
else:
s.sum_of_phases = {-phase: coeff for phase, coeff in self.sum_of_phases.items()}
s.sum_of_phases = {conjugate_phase(phase): coeff for phase, coeff in self.sum_of_phases.items()}
return s

def conjugate(self) -> 'Scalar':
Expand Down
10 changes: 9 additions & 1 deletion pyzx/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def __pow__(self, other: int) -> 'Poly':
return self * (self ** (other - 1))

def __mod__(self, other: int) -> 'Poly':
return Poly([(c % other, t) for c, t in self.terms if not isinstance(c, complex)])
return Poly([(c if isinstance(c, complex) else c % other, t) for c, t in self.terms])

def __repr__(self) -> str:
return f'Poly({str(self)})'
Expand Down Expand Up @@ -364,6 +364,14 @@ def copy(self) -> 'Poly':
"""Return a shallow copy of the polynomial."""
return Poly([(c, t) for c, t in self.terms])

def conjugate(self) -> 'Poly':
"""Return the complex conjugate of the polynomial."""
def conj_coeff(c):
if isinstance(c, complex):
return c.conjugate()
return c
return Poly([(conj_coeff(c), t) for c, t in self.terms])

def rebind_variables_to_registry(self, new_registry: VarRegistry) -> None:
"""Rebind all variables in this polynomial to the given registry."""
for _, term in self.terms:
Expand Down
35 changes: 34 additions & 1 deletion tests/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
sys.path.append('..')
sys.path.append('.')
mydir = os.path.dirname(__file__)
from fractions import Fraction
from pyzx.generate import cliffordT, cliffords
from pyzx.simplify import clifford_simp
from pyzx.extract import extract_circuit
from pyzx.circuit import Circuit, PhaseGadget
from pyzx.circuit.gates import ParityPhase, FSim
from pyzx.circuit.gates import ParityPhase, FSim, ZPhase, XPhase
from pyzx.symbolic import Poly, Term, Var, new_var
from pyzx.utils import VertexType, EdgeType
from fractions import Fraction

Expand Down Expand Up @@ -252,5 +254,36 @@ def test_fsim_reposition(self):
self.assertEqual(g.control, 0)
self.assertEqual(g.target, 1)

class TestGateAdjointComplexPhase(unittest.TestCase):

def test_zphase_adjoint_real_phase(self):
gate = ZPhase(0, Fraction(1, 4))
adj = gate.to_adjoint()
self.assertEqual(adj.phase, Fraction(-1, 4))

def test_zphase_adjoint_symbolic_real_phase(self):
phase = new_var('theta', is_bool=False)
gate = ZPhase(0, phase)
adj = gate.to_adjoint()
self.assertEqual(adj.phase, -phase)

def test_zphase_adjoint_complex_phase(self):
var = Var('x')
phase = Poly([((3+2j), Term([(var, 1)]))])
gate = ZPhase(0, phase)
adj = gate.to_adjoint()

self.assertEqual(len(adj.phase.terms), 1)
self.assertEqual(adj.phase.terms[0][0], (-3+2j))

def test_xphase_adjoint_complex_phase(self):
var = Var('y')
phase = Poly([((1+1j), Term([(var, 1)]))])
gate = XPhase(0, phase)
adj = gate.to_adjoint()

self.assertEqual(adj.phase.terms[0][0], (-1+1j))


if __name__ == '__main__':
unittest.main()
13 changes: 13 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,19 @@ def test_adjoint_scalar(self):
g_adj = g.adjoint()
self.assertAlmostEqual(g_adj.scalar.to_number(), scalar.to_number().conjugate())

def test_adjoint_complex_vertex_phase(self):
from pyzx.symbolic import Poly, Term, Var
g = Graph()
var = Var('z')
phase = Poly([((2+3j), Term([(var, 1)]))])
v = g.add_vertex(VertexType.Z, phase=phase)
g_adj = g.adjoint()
adj_v = list(g_adj.vertices())[0]
adj_phase = g_adj.phase(adj_v)

self.assertEqual(len(adj_phase.terms), 1)
self.assertEqual(adj_phase.terms[0][0], (-2+3j))

@unittest.skipUnless(np, "numpy needs to be installed for this to run")
def test_remove_isolated_vertex_preserves_semantics(self):
g = Graph()
Expand Down
33 changes: 32 additions & 1 deletion tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
from fractions import Fraction
from pyzx.graph.scalar import Scalar
from pyzx.symbolic import Poly, new_var
from pyzx.symbolic import Poly, Term, Var, new_var

if __name__ == '__main__':
sys.path.append('..')
Expand Down Expand Up @@ -514,6 +514,37 @@ def test_string_representations_with_poly(self):
unicode_repr = scalar.to_unicode()
self.assertIsInstance(unicode_repr, str)

def test_scalar_conjugate_complex_poly_phase(self):
var = Var('a')
phase = Poly([((2+3j), Term([(var, 1)]))])
s = Scalar()
s.phase = phase
conj = s.conjugate()

self.assertEqual(len(conj.phase.terms), 1)
self.assertEqual(conj.phase.terms[0][0], (-2+3j))

def test_scalar_conjugate_phasenodes_complex(self):
var = Var('b')
node_phase = Poly([((1+2j), Term([(var, 1)]))])
s = Scalar()
s.phasenodes = [node_phase]
conj = s.conjugate()

self.assertEqual(len(conj.phasenodes), 1)
self.assertEqual(conj.phasenodes[0].terms[0][0], (-1+2j))

def test_scalar_conjugate_sum_of_phases_complex(self):
var = Var('c')
phase_key = Poly([((4+5j), Term([(var, 1)]))])
s = Scalar()
s.sum_of_phases = {phase_key: 2}
conj = s.conjugate()

self.assertEqual(len(conj.sum_of_phases), 1)
conjugated_key = list(conj.sum_of_phases.keys())[0]
self.assertEqual(conjugated_key.terms[0][0], (-4+5j))
self.assertEqual(conj.sum_of_phases[conjugated_key], 2)


if __name__ == '__main__':
Expand Down
48 changes: 47 additions & 1 deletion tests/test_symbolic_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest
from fractions import Fraction

from pyzx.symbolic import Poly, VarRegistry, new_const, new_var, parse
from pyzx.symbolic import Poly, Term, Var, VarRegistry, new_const, new_var, parse

if __name__ == '__main__':
sys.path.append('..')
Expand Down Expand Up @@ -255,7 +255,53 @@ def test_operator_precedence(self):
self.assertEqual(result, expected)


class TestPolyConjugate(unittest.TestCase):

def test_conjugate_real_coefficients(self):
self.assertEqual(Poly([(3, Term([]))]).conjugate(), Poly([(3, Term([]))]))
self.assertEqual(Poly([(Fraction(1, 2), Term([]))]).conjugate(), Poly([(Fraction(1, 2), Term([]))]))
self.assertEqual(Poly([(2.5, Term([]))]).conjugate(), Poly([(2.5, Term([]))]))

def test_conjugate_complex_coefficients(self):
p = Poly([((3+2j), Term([]))])
result = p.conjugate()
self.assertEqual(len(result.terms), 1)
self.assertEqual(result.terms[0][0], (3-2j))

var = Var('x')
p = Poly([((1+2j), Term([(var, 1)])), ((3-4j), Term([]))])
result = p.conjugate()
coeffs = {t: c for c, t in result.terms}
self.assertEqual(coeffs[Term([(var, 1)])], (1-2j))
self.assertEqual(coeffs[Term([])], (3+4j))

def test_conjugate_pure_imaginary(self):
p = Poly([(2j, Term([]))])
self.assertEqual(p.conjugate().terms[0][0], -2j)


class TestPolyModulo(unittest.TestCase):

def test_modulo_real_coefficients(self):
p = Poly([(5, Term([])), (3, Term([]))])
result = p % 2
coeffs = [c for c, _ in result.terms]
self.assertEqual(coeffs, [1, 1])

def test_modulo_preserves_complex_coefficients(self):
var = Var('x')
p = Poly([((2+3j), Term([(var, 1)])), (5, Term([]))])
result = p % 2
coeffs = {str(t): c for c, t in result.terms}

self.assertEqual(coeffs['x'], (2+3j))
self.assertEqual(coeffs[''], 1)

def test_modulo_pure_complex(self):
p = Poly([((1+2j), Term([]))])
result = p % 2
self.assertEqual(len(result.terms), 1)
self.assertEqual(result.terms[0][0], (1+2j))


if __name__ == '__main__':
Expand Down