From 77a2f661efd1629b54b747500e310c0f6a4fd746 Mon Sep 17 00:00:00 2001 From: David Yonge-Mallo Date: Sat, 31 Jan 2026 12:46:10 +0100 Subject: [PATCH 1/2] handle complex phases in .to_adjoint and .conjugate --- pyzx/circuit/gates.py | 8 ++++++-- pyzx/graph/base.py | 8 ++++---- pyzx/graph/scalar.py | 9 ++++++--- pyzx/symbolic.py | 8 ++++++++ tests/test_circuit.py | 35 +++++++++++++++++++++++++++++++++- tests/test_scalar.py | 33 +++++++++++++++++++++++++++++++- tests/test_symbolic_parsing.py | 25 +++++++++++++++++++++++- 7 files changed, 114 insertions(+), 12 deletions(-) diff --git a/pyzx/circuit/gates.py b/pyzx/circuit/gates.py index 87e62723..2d05eb96 100644 --- a/pyzx/circuit/gates.py +++ b/pyzx/circuit/gates.py @@ -26,7 +26,7 @@ from ..utils import EdgeType, VertexType, FractionLike from ..graph.base import BaseGraph, VT, ET -from ..symbolic import new_var +from ..symbolic import new_var, Poly # We need this type variable so that the subclasses of Gate return the correct type for functions like copy() Tvar = TypeVar('Tvar', bound='Gate') @@ -233,7 +233,11 @@ def copy(self: Tvar) -> Tvar: def to_adjoint(self: Tvar) -> Tvar: g = self.copy() if hasattr(g, "phase"): - g.phase = -g.phase + phase = g.phase + if isinstance(phase, Poly): + g.phase = -phase.conjugate() + else: + g.phase = -phase if hasattr(g, "adjoint"): g.adjoint = not g.adjoint return g diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index 894a7dbf..e0f03387 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -484,9 +484,8 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> BaseGraph[VT,E g.merge_vdata = self.merge_vdata # type: ignore for name in self.var_registry.vars(): g.var_registry.set_type(name, self.var_registry.get_type(name)) - mult:int = 1 - if adjoint: - mult = -1 + def adjoint_phase(phase: FractionLike) -> FractionLike: + return -phase.conjugate() #g.add_vertices(self.num_vertices()) ty = self.types() @@ -496,7 +495,8 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> BaseGraph[VT,E maxr = self.depth() vtab = dict() for v in self.vertices(): - i = g.add_vertex(ty[v],phase=mult*ph[v]) + phase = ph[v] if not adjoint else adjoint_phase(ph[v]) + i = g.add_vertex(ty[v], phase=phase) if v in qs: g.set_qubit(i,qs[v]) if v in rs: if adjoint: g.set_row(i, maxr-rs[v]) diff --git a/pyzx/graph/scalar.py b/pyzx/graph/scalar.py index e064c479..5d06f0db 100644 --- a/pyzx/graph/scalar.py +++ b/pyzx/graph/scalar.py @@ -115,17 +115,20 @@ def copy(self, conjugate: bool = False) -> 'Scalar': Returns: A copy of the Scalar """ + def conjugate_phase(phase: FractionLike) -> FractionLike: + return -phase.conjugate() + s = Scalar() s.power2 = self.power2 - s.phase = self.phase if not conjugate else -self.phase - s.phasenodes = copy.copy(self.phasenodes) if not conjugate else [-p for p in self.phasenodes] + s.phase = self.phase if not conjugate else conjugate_phase(self.phase) + s.phasenodes = copy.copy(self.phasenodes) if not conjugate else [conjugate_phase(p) for p in self.phasenodes] s.floatfactor = self.floatfactor if not conjugate else self.floatfactor.conjugate() s.is_unknown = self.is_unknown s.is_zero = self.is_zero if not conjugate: s.sum_of_phases = copy.deepcopy(self.sum_of_phases) else: - s.sum_of_phases = {-phase: coeff for phase, coeff in self.sum_of_phases.items()} + s.sum_of_phases = {conjugate_phase(phase): coeff for phase, coeff in self.sum_of_phases.items()} return s def conjugate(self) -> 'Scalar': diff --git a/pyzx/symbolic.py b/pyzx/symbolic.py index 83ee6879..cc82b17a 100644 --- a/pyzx/symbolic.py +++ b/pyzx/symbolic.py @@ -364,6 +364,14 @@ def copy(self) -> 'Poly': """Return a shallow copy of the polynomial.""" return Poly([(c, t) for c, t in self.terms]) + def conjugate(self) -> 'Poly': + """Return the complex conjugate of the polynomial.""" + def conj_coeff(c): + if isinstance(c, complex): + return c.conjugate() + return c + return Poly([(conj_coeff(c), t) for c, t in self.terms]) + def rebind_variables_to_registry(self, new_registry: VarRegistry) -> None: """Rebind all variables in this polynomial to the given registry.""" for _, term in self.terms: diff --git a/tests/test_circuit.py b/tests/test_circuit.py index b48392ec..c564b60d 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -27,11 +27,13 @@ sys.path.append('..') sys.path.append('.') mydir = os.path.dirname(__file__) +from fractions import Fraction from pyzx.generate import cliffordT, cliffords from pyzx.simplify import clifford_simp from pyzx.extract import extract_circuit from pyzx.circuit import Circuit, PhaseGadget -from pyzx.circuit.gates import ParityPhase, FSim +from pyzx.circuit.gates import ParityPhase, FSim, ZPhase, XPhase +from pyzx.symbolic import Poly, Term, Var, new_var from pyzx.utils import VertexType, EdgeType from fractions import Fraction @@ -252,5 +254,36 @@ def test_fsim_reposition(self): self.assertEqual(g.control, 0) self.assertEqual(g.target, 1) +class TestGateAdjointComplexPhase(unittest.TestCase): + + def test_zphase_adjoint_real_phase(self): + gate = ZPhase(0, Fraction(1, 4)) + adj = gate.to_adjoint() + self.assertEqual(adj.phase, Fraction(-1, 4)) + + def test_zphase_adjoint_symbolic_real_phase(self): + phase = new_var('theta', is_bool=False) + gate = ZPhase(0, phase) + adj = gate.to_adjoint() + self.assertEqual(adj.phase, -phase) + + def test_zphase_adjoint_complex_phase(self): + var = Var('x') + phase = Poly([((3+2j), Term([(var, 1)]))]) + gate = ZPhase(0, phase) + adj = gate.to_adjoint() + + self.assertEqual(len(adj.phase.terms), 1) + self.assertEqual(adj.phase.terms[0][0], (-3+2j)) + + def test_xphase_adjoint_complex_phase(self): + var = Var('y') + phase = Poly([((1+1j), Term([(var, 1)]))]) + gate = XPhase(0, phase) + adj = gate.to_adjoint() + + self.assertEqual(adj.phase.terms[0][0], (-1+1j)) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_scalar.py b/tests/test_scalar.py index 4f8e3bd7..13909a35 100644 --- a/tests/test_scalar.py +++ b/tests/test_scalar.py @@ -19,7 +19,7 @@ import numpy as np from fractions import Fraction from pyzx.graph.scalar import Scalar -from pyzx.symbolic import Poly, new_var +from pyzx.symbolic import Poly, Term, Var, new_var if __name__ == '__main__': sys.path.append('..') @@ -514,6 +514,37 @@ def test_string_representations_with_poly(self): unicode_repr = scalar.to_unicode() self.assertIsInstance(unicode_repr, str) + def test_scalar_conjugate_complex_poly_phase(self): + var = Var('a') + phase = Poly([((2+3j), Term([(var, 1)]))]) + s = Scalar() + s.phase = phase + conj = s.conjugate() + + self.assertEqual(len(conj.phase.terms), 1) + self.assertEqual(conj.phase.terms[0][0], (-2+3j)) + + def test_scalar_conjugate_phasenodes_complex(self): + var = Var('b') + node_phase = Poly([((1+2j), Term([(var, 1)]))]) + s = Scalar() + s.phasenodes = [node_phase] + conj = s.conjugate() + + self.assertEqual(len(conj.phasenodes), 1) + self.assertEqual(conj.phasenodes[0].terms[0][0], (-1+2j)) + + def test_scalar_conjugate_sum_of_phases_complex(self): + var = Var('c') + phase_key = Poly([((4+5j), Term([(var, 1)]))]) + s = Scalar() + s.sum_of_phases = {phase_key: 2} + conj = s.conjugate() + + self.assertEqual(len(conj.sum_of_phases), 1) + conjugated_key = list(conj.sum_of_phases.keys())[0] + self.assertEqual(conjugated_key.terms[0][0], (-4+5j)) + self.assertEqual(conj.sum_of_phases[conjugated_key], 2) if __name__ == '__main__': diff --git a/tests/test_symbolic_parsing.py b/tests/test_symbolic_parsing.py index 892ef07e..ef8488fe 100644 --- a/tests/test_symbolic_parsing.py +++ b/tests/test_symbolic_parsing.py @@ -2,7 +2,7 @@ import unittest from fractions import Fraction -from pyzx.symbolic import Poly, VarRegistry, new_const, new_var, parse +from pyzx.symbolic import Poly, Term, Var, VarRegistry, new_const, new_var, parse if __name__ == '__main__': sys.path.append('..') @@ -255,6 +255,29 @@ def test_operator_precedence(self): self.assertEqual(result, expected) +class TestPolyConjugate(unittest.TestCase): + + def test_conjugate_real_coefficients(self): + self.assertEqual(Poly([(3, Term([]))]).conjugate(), Poly([(3, Term([]))])) + self.assertEqual(Poly([(Fraction(1, 2), Term([]))]).conjugate(), Poly([(Fraction(1, 2), Term([]))])) + self.assertEqual(Poly([(2.5, Term([]))]).conjugate(), Poly([(2.5, Term([]))])) + + def test_conjugate_complex_coefficients(self): + p = Poly([((3+2j), Term([]))]) + result = p.conjugate() + self.assertEqual(len(result.terms), 1) + self.assertEqual(result.terms[0][0], (3-2j)) + + var = Var('x') + p = Poly([((1+2j), Term([(var, 1)])), ((3-4j), Term([]))]) + result = p.conjugate() + coeffs = {t: c for c, t in result.terms} + self.assertEqual(coeffs[Term([(var, 1)])], (1-2j)) + self.assertEqual(coeffs[Term([])], (3+4j)) + + def test_conjugate_pure_imaginary(self): + p = Poly([(2j, Term([]))]) + self.assertEqual(p.conjugate().terms[0][0], -2j) From 21b6f046bd78dfc8ae9eab2fa17d236abe970369 Mon Sep 17 00:00:00 2001 From: David Yonge-Mallo Date: Thu, 5 Feb 2026 16:52:10 +0100 Subject: [PATCH 2/2] fix Poly.__mod__ to preserve complex coefficients --- pyzx/symbolic.py | 2 +- tests/test_graph.py | 13 +++++++++++++ tests/test_symbolic_parsing.py | 23 +++++++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/pyzx/symbolic.py b/pyzx/symbolic.py index cc82b17a..6ceb6c5d 100644 --- a/pyzx/symbolic.py +++ b/pyzx/symbolic.py @@ -263,7 +263,7 @@ def __pow__(self, other: int) -> 'Poly': return self * (self ** (other - 1)) def __mod__(self, other: int) -> 'Poly': - return Poly([(c % other, t) for c, t in self.terms if not isinstance(c, complex)]) + return Poly([(c if isinstance(c, complex) else c % other, t) for c, t in self.terms]) def __repr__(self) -> str: return f'Poly({str(self)})' diff --git a/tests/test_graph.py b/tests/test_graph.py index b594325d..11e9d0fa 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -171,6 +171,19 @@ def test_adjoint_scalar(self): g_adj = g.adjoint() self.assertAlmostEqual(g_adj.scalar.to_number(), scalar.to_number().conjugate()) + def test_adjoint_complex_vertex_phase(self): + from pyzx.symbolic import Poly, Term, Var + g = Graph() + var = Var('z') + phase = Poly([((2+3j), Term([(var, 1)]))]) + v = g.add_vertex(VertexType.Z, phase=phase) + g_adj = g.adjoint() + adj_v = list(g_adj.vertices())[0] + adj_phase = g_adj.phase(adj_v) + + self.assertEqual(len(adj_phase.terms), 1) + self.assertEqual(adj_phase.terms[0][0], (-2+3j)) + @unittest.skipUnless(np, "numpy needs to be installed for this to run") def test_remove_isolated_vertex_preserves_semantics(self): g = Graph() diff --git a/tests/test_symbolic_parsing.py b/tests/test_symbolic_parsing.py index ef8488fe..0fd064a7 100644 --- a/tests/test_symbolic_parsing.py +++ b/tests/test_symbolic_parsing.py @@ -280,6 +280,29 @@ def test_conjugate_pure_imaginary(self): self.assertEqual(p.conjugate().terms[0][0], -2j) +class TestPolyModulo(unittest.TestCase): + + def test_modulo_real_coefficients(self): + p = Poly([(5, Term([])), (3, Term([]))]) + result = p % 2 + coeffs = [c for c, _ in result.terms] + self.assertEqual(coeffs, [1, 1]) + + def test_modulo_preserves_complex_coefficients(self): + var = Var('x') + p = Poly([((2+3j), Term([(var, 1)])), (5, Term([]))]) + result = p % 2 + coeffs = {str(t): c for c, t in result.terms} + + self.assertEqual(coeffs['x'], (2+3j)) + self.assertEqual(coeffs[''], 1) + + def test_modulo_pure_complex(self): + p = Poly([((1+2j), Term([]))]) + result = p % 2 + self.assertEqual(len(result.terms), 1) + self.assertEqual(result.terms[0][0], (1+2j)) + if __name__ == '__main__': unittest.main()