From cccbf4003d7f30f972d2f56c80e5cfd8c35768f4 Mon Sep 17 00:00:00 2001 From: David Yonge-Mallo Date: Thu, 29 Jan 2026 07:45:23 +0100 Subject: [PATCH] Add support for H-boxes with arbitrary complex labels --- pyzx/drawing.py | 28 +++- pyzx/graph/base.py | 4 +- pyzx/graph/jsonparser.py | 8 +- pyzx/rewrite_rules/bialgebra_rule.py | 6 +- pyzx/rewrite_rules/copy_rule.py | 6 +- pyzx/rewrite_rules/fuse_hboxes_rule.py | 6 +- pyzx/rewrite_rules/had_edge_hbox_rule.py | 4 +- pyzx/rewrite_rules/hbox_cancel_rule.py | 8 +- pyzx/rewrite_rules/hbox_not_remove_rule.py | 4 +- pyzx/rewrite_rules/hpivot_rule.py | 11 +- pyzx/rewrite_rules/par_hbox_rule.py | 4 +- pyzx/rewrite_rules/push_pauli_rule.py | 4 +- pyzx/rewrite_rules/zero_hbox_rule.py | 13 +- pyzx/tensor.py | 14 +- pyzx/tikz.py | 44 ++++-- pyzx/utils.py | 32 +++++ tests/{test_hbox_cancel.py => test_hbox.py} | 140 +++++++++++++++++++- tests/test_jsonparser.py | 62 ++++++++- tests/test_tensor.py | 75 ++++++++++- 19 files changed, 427 insertions(+), 46 deletions(-) rename tests/{test_hbox_cancel.py => test_hbox.py} (65%) diff --git a/pyzx/drawing.py b/pyzx/drawing.py index b236801a..10faa690 100644 --- a/pyzx/drawing.py +++ b/pyzx/drawing.py @@ -38,7 +38,7 @@ lines: Any = None -from .utils import settings, get_mode, phase_to_s, EdgeType, VertexType, FloatInt, get_z_box_label +from .utils import settings, get_mode, phase_to_s, VertexType, FloatInt, get_z_box_label, get_h_box_label, hbox_has_complex_label from .graph.base import BaseGraph, VT, ET from .circuit import Circuit @@ -241,7 +241,16 @@ def draw_matplotlib( t = g.type(v) a = g.phase(v) a_offset = 0.5 - phase_str = phase_to_s(a, t) + # Handle H-boxes with complex labels. + if t == VertexType.H_BOX and hbox_has_complex_label(g, v): + label = get_h_box_label(g, v) + # Standard Hadamard (-1) is displayed as empty. + if cmath.isclose(label, -1): + phase_str = '' + else: + phase_str = str(label) + else: + phase_str = phase_to_s(a, t) if t == VertexType.Z: ax.add_patch(patches.Circle(p, 0.2, facecolor='#ccffcc', edgecolor='black', zorder=1)) @@ -332,12 +341,25 @@ def graph_json(g: BaseGraph[VT, ET], vdata: Optional[List[str]]=None, pauli_web: Optional[PauliWeb[VT,ET]]=None) -> str: + def get_phase_str(v): + """Get phase string for a vertex, handling complex labels.""" + ty = g.type(v) + if ty == VertexType.Z_BOX: + return str(get_z_box_label(g, v)) + elif ty == VertexType.H_BOX and hbox_has_complex_label(g, v): + label = get_h_box_label(g, v) + # Standard Hadamard (-1) is displayed as empty. + if cmath.isclose(label, -1): + return '' + return str(label) + return phase_to_s(g.phase(v), ty, poly_with_pi=True) + nodes = [{'name': str(v), 'x': float(coords[v][0]), 'y': float(coords[v][1]), 'z': float(coords[v][2]), 't': g.type(v), - 'phase': phase_to_s(g.phase(v), g.type(v), poly_with_pi=True) if g.type(v) != VertexType.Z_BOX else str(get_z_box_label(g, v)), + 'phase': get_phase_str(v), 'ground': g.is_ground(v), 'vdata': [(key, g.vdata(v, key)) for key in vdata or [] if g.vdata(v, key, None) is not None], diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index 894a7dbf..5bb673f0 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -340,7 +340,9 @@ def add_vertex(self, ) -> VT: """Add a single vertex to the graph and return its index. The optional parameters allow you to respectively set - the type, qubit index, row index and phase of the vertex.""" + the type, qubit index, row index and phase of the vertex. + For H-boxes and Z-boxes with complex labels, use set_h_box_label + or set_z_box_label after creating the vertex.""" if index is not None: self.add_vertex_indexed(index) v = index diff --git a/pyzx/graph/jsonparser.py b/pyzx/graph/jsonparser.py index d9f77374..17b72039 100644 --- a/pyzx/graph/jsonparser.py +++ b/pyzx/graph/jsonparser.py @@ -294,6 +294,12 @@ def graph_to_dict_old(g: BaseGraph[VT,ET], include_scalar: bool=True) -> Dict[st elif t==VertexType.H_BOX: node_vs[name]["data"]["type"] = "hadamard" node_vs[name]["data"]["is_edge"] = "false" + # Only export label if set; legacy H-boxes use phase field instead. + hbox_label = g.vdata(v, 'label', None) + if hbox_label is not None: + if isinstance(hbox_label, Fraction): + hbox_label = phase_to_s(hbox_label, limit_denominator=False) + node_vs[name]["annotation"]["label"] = hbox_label elif t==VertexType.W_INPUT: node_vs[name]["data"]["type"] = "W_input" elif t==VertexType.W_OUTPUT: @@ -301,7 +307,7 @@ def graph_to_dict_old(g: BaseGraph[VT,ET], include_scalar: bool=True) -> Dict[st elif t==VertexType.Z_BOX: node_vs[name]["data"]["type"] = "Z_box" zbox_label = g.vdata(v, 'label', 1) - if type(zbox_label) == Fraction: + if isinstance(zbox_label, Fraction): zbox_label = phase_to_s(zbox_label, limit_denominator=False) node_vs[name]["annotation"]["label"] = zbox_label else: raise Exception("Unkown vertex type "+ str(t)) diff --git a/pyzx/rewrite_rules/bialgebra_rule.py b/pyzx/rewrite_rules/bialgebra_rule.py index 5eaa52b6..b28384ae 100644 --- a/pyzx/rewrite_rules/bialgebra_rule.py +++ b/pyzx/rewrite_rules/bialgebra_rule.py @@ -37,7 +37,8 @@ from collections import defaultdict from typing import Callable, Optional, List, Tuple, Dict -from pyzx.utils import EdgeType, VertexType, is_pauli +from pyzx.utils import (EdgeType, VertexType, is_pauli, + hbox_has_complex_label, get_h_box_label, set_h_box_label) from pyzx.graph.base import BaseGraph, VT, ET, upair RewriteOutputType = Tuple[Dict[Tuple[VT,VT],List[int]], List[VT], List[ET], bool] @@ -94,6 +95,9 @@ def unsafe_bialgebra(g: BaseGraph[VT,ET], v1: VT, v2: VT ) -> bool: r = 0.4*g.row(other_vertex) + 0.6*g.row(v[i]) newv = g.add_vertex(g.type(v[j]), qubit=q, row=r) g.set_phase(newv, g.phase(v[j])) + # Copy complex label if H-box has one. + if g.type(v[j]) == VertexType.H_BOX and hbox_has_complex_label(g, v[j]): + set_h_box_label(g, newv, get_h_box_label(g, v[j])) new_verts[i].append(newv) if other_vertex == v[j]: q = 0.4*g.qubit(v[i]) + 0.6*g.qubit(other_vertex) diff --git a/pyzx/rewrite_rules/copy_rule.py b/pyzx/rewrite_rules/copy_rule.py index eaaa5a3b..d6ceb823 100644 --- a/pyzx/rewrite_rules/copy_rule.py +++ b/pyzx/rewrite_rules/copy_rule.py @@ -30,7 +30,7 @@ 'unsafe_copy',] from typing import Optional -from pyzx.utils import EdgeType, VertexType, toggle_vertex, vertex_is_zx +from pyzx.utils import EdgeType, VertexType, toggle_vertex, vertex_is_zx, is_standard_hbox from pyzx.graph.base import BaseGraph, ET, VT @@ -95,6 +95,10 @@ def check_copy_h( et = g.edge_type(g.edge(v, w)) if tw == VertexType.H_BOX: + # Only apply to standard H-boxes (label=-1 or phase=1). + # Non-standard H-boxes have different scalar factors. + if not is_standard_hbox(g, w): + return None # X pi/0 can always copy through H-box # But if v is Z, then it can only copy if the phase is 1 if et == EdgeType.HADAMARD: diff --git a/pyzx/rewrite_rules/fuse_hboxes_rule.py b/pyzx/rewrite_rules/fuse_hboxes_rule.py index 2d20264b..2681e1ac 100644 --- a/pyzx/rewrite_rules/fuse_hboxes_rule.py +++ b/pyzx/rewrite_rules/fuse_hboxes_rule.py @@ -29,7 +29,7 @@ from typing import Dict, List, Tuple, Set -from pyzx.utils import EdgeType, VertexType +from pyzx.utils import EdgeType, VertexType, is_standard_hbox from pyzx.graph.base import BaseGraph, ET, VT, upair @@ -50,7 +50,7 @@ def check_connected_hboxes(g: BaseGraph[VT ,ET], v: VT, w: VT) -> bool: if g.edge_type(e) != EdgeType.HADAMARD: return False v1 ,v2 = g.edge_st(e) if ty[v1] != VertexType.H_BOX or ty[v2] != VertexType.H_BOX: return False - if g.phase(v1) != 1 and g.phase(v2) != 1: return False + if not is_standard_hbox(g, v1) and not is_standard_hbox(g, v2): return False m.add(e) return True @@ -67,7 +67,7 @@ def unsafe_fuse_hboxes(g: BaseGraph[VT ,ET], v1: VT, v2: VT) -> bool: rem_verts = [] etab: Dict[Tuple[VT ,VT], List[int]] = {} - if g.phase(v2) != 1: # at most one of v1 and v2 has a phase different from 1 + if not is_standard_hbox(g, v2): # Ensure v2 is the standard one. v1, v2 = v2, v1 rem_verts.append(v2) g.scalar.add_power(1) diff --git a/pyzx/rewrite_rules/had_edge_hbox_rule.py b/pyzx/rewrite_rules/had_edge_hbox_rule.py index c4dc9fa0..91c740d7 100644 --- a/pyzx/rewrite_rules/had_edge_hbox_rule.py +++ b/pyzx/rewrite_rules/had_edge_hbox_rule.py @@ -28,14 +28,14 @@ 'had_edge_to_hbox', 'unsafe_had_edge_to_hbox'] -from pyzx.utils import EdgeType, VertexType +from pyzx.utils import EdgeType, VertexType, is_standard_hbox from pyzx.graph.base import BaseGraph, ET, VT from pyzx.rewrite_rules.euler_rule import check_hadamard_edge def check_hadamard(g: BaseGraph[VT ,ET], v: VT) -> bool: """Returns whether the vertex v in graph g is a Hadamard gate.""" if g.type(v) != VertexType.H_BOX: return False - if g.phase(v) != 1: return False + if not is_standard_hbox(g, v): return False if g.vertex_degree(v) != 2: return False return True diff --git a/pyzx/rewrite_rules/hbox_cancel_rule.py b/pyzx/rewrite_rules/hbox_cancel_rule.py index 9cda1bbb..4b91f3f9 100644 --- a/pyzx/rewrite_rules/hbox_cancel_rule.py +++ b/pyzx/rewrite_rules/hbox_cancel_rule.py @@ -31,7 +31,7 @@ 'unsafe_hbox_cancel'] from pyzx.graph.base import BaseGraph, VT, ET -from pyzx.utils import EdgeType, VertexType +from pyzx.utils import EdgeType, VertexType, is_standard_hbox def check_hbox_cancel(g: BaseGraph[VT, ET], v: VT) -> bool: @@ -40,7 +40,7 @@ def check_hbox_cancel(g: BaseGraph[VT, ET], v: VT) -> bool: return False if g.type(v) != VertexType.H_BOX: return False - if g.phase(v) != 1: + if not is_standard_hbox(g, v): return False if g.vertex_degree(v) != 2: return False @@ -51,7 +51,7 @@ def check_hbox_cancel(g: BaseGraph[VT, ET], v: VT) -> bool: e = g.edge(v, n) if g.edge_type(e) == EdgeType.SIMPLE: if (g.type(n) == VertexType.H_BOX and - g.phase(n) == 1 and + is_standard_hbox(g, n) and g.vertex_degree(n) == 2): return True @@ -79,7 +79,7 @@ def unsafe_hbox_cancel(g: BaseGraph[VT, ET], v: VT) -> bool: e = g.edge(v, n) if g.edge_type(e) == EdgeType.SIMPLE: if (g.type(n) == VertexType.H_BOX and - g.phase(n) == 1 and + is_standard_hbox(g, n) and g.vertex_degree(n) == 2): # Found two adjacent H-boxes connected by a simple edge. diff --git a/pyzx/rewrite_rules/hbox_not_remove_rule.py b/pyzx/rewrite_rules/hbox_not_remove_rule.py index b791959f..1aa7173b 100644 --- a/pyzx/rewrite_rules/hbox_not_remove_rule.py +++ b/pyzx/rewrite_rules/hbox_not_remove_rule.py @@ -30,7 +30,7 @@ from typing import Dict, List, Tuple -from pyzx.utils import EdgeType, VertexType +from pyzx.utils import EdgeType, VertexType, is_standard_hbox from pyzx.graph.base import BaseGraph, ET, VT, upair @@ -63,7 +63,7 @@ def check_hbox_parallel_not( types = g.types() if not (h in g.vertices() and n in g.vertices()): return False - if types[h] != VertexType.H_BOX or phases[h] != 1: return False + if types[h] != VertexType.H_BOX or not is_standard_hbox(g, h): return False if g.vertex_degree(n) != 2 or phases[n] != 1: return False # If it turns out to be useful, this rule can be generalised to allow spiders of arbitrary phase here diff --git a/pyzx/rewrite_rules/hpivot_rule.py b/pyzx/rewrite_rules/hpivot_rule.py index a26bc422..6e87e92e 100644 --- a/pyzx/rewrite_rules/hpivot_rule.py +++ b/pyzx/rewrite_rules/hpivot_rule.py @@ -25,7 +25,7 @@ from fractions import Fraction from itertools import combinations from typing import List, Tuple, Optional -from pyzx.utils import VertexType, toggle_edge, FractionLike, FloatInt +from pyzx.utils import VertexType, toggle_edge, FractionLike, FloatInt, is_standard_hbox from pyzx.graph.base import BaseGraph, ET, VT @@ -79,7 +79,7 @@ def match_hpivot( (vertices is None or (vertices[0] == h)) and g.vertex_degree(h) == 2 and types[h] == VertexType.H_BOX and - phases[h] == 1 + is_standard_hbox(g, h) ): continue v0, v1 = g.neighbors(h) @@ -94,10 +94,9 @@ def match_hpivot( v1b = [v for v in v1n if types[v] == VertexType.BOUNDARY] v1h = [v for v in v1n if types[v] == VertexType.H_BOX and v != h] - # check that at least one of v0 or v1 has all pi phases on adjacent - # hboxes. - if not (all(phases[v] == 1 for v in v0h)): - if not (all(phases[v] == 1 for v in v1h)): + # check that at least one of v0 or v1 has all standard H-boxes adjacent. + if not (all(is_standard_hbox(g, v) for v in v0h)): + if not (all(is_standard_hbox(g, v) for v in v1h)): continue else: # interchange the roles of v0 <-> v1 diff --git a/pyzx/rewrite_rules/par_hbox_rule.py b/pyzx/rewrite_rules/par_hbox_rule.py index acf1d1b8..8beed7de 100644 --- a/pyzx/rewrite_rules/par_hbox_rule.py +++ b/pyzx/rewrite_rules/par_hbox_rule.py @@ -28,7 +28,7 @@ from typing import Dict, List, Tuple, Optional, Set, FrozenSet -from pyzx.utils import EdgeType, VertexType +from pyzx.utils import EdgeType, VertexType, hbox_has_complex_label from pyzx.graph.base import BaseGraph, ET, VT @@ -68,6 +68,7 @@ def match_par_hbox( ty = g.types() for h in candidates: if ty[h] != VertexType.H_BOX: continue + if hbox_has_complex_label(g, h): continue suitable = True neighbors_regular = set() neighbors_NOT = set() @@ -164,6 +165,7 @@ def match_par_hbox_intro(g: BaseGraph[VT, ET], vertices: Optional[List[VT]]=None ty = g.types() for h in candidates: if ty[h] != VertexType.H_BOX: continue + if hbox_has_complex_label(g, h): continue suitable = True neighbors_regular = set() neighbors_NOT = set() diff --git a/pyzx/rewrite_rules/push_pauli_rule.py b/pyzx/rewrite_rules/push_pauli_rule.py index 24bb9a5c..02f743c6 100644 --- a/pyzx/rewrite_rules/push_pauli_rule.py +++ b/pyzx/rewrite_rules/push_pauli_rule.py @@ -32,7 +32,7 @@ from typing import List, Dict, Tuple -from pyzx.utils import EdgeType, VertexType, FractionLike, phase_is_pauli, vertex_is_zx, toggle_vertex +from pyzx.utils import EdgeType, VertexType, FractionLike, phase_is_pauli, vertex_is_zx, toggle_vertex, is_standard_hbox from pyzx.graph.base import BaseGraph, VT, ET, upair def check_pauli(g: BaseGraph[VT,ET], v: VT, w: VT) -> bool: @@ -58,7 +58,7 @@ def check_pauli(g: BaseGraph[VT,ET], v: VT, w: VT) -> bool: if ((types[v] == types[w] and et == EdgeType.HADAMARD) or (vertex_is_zx(types[v]) and types[v] != types[w] and et == EdgeType.SIMPLE) or - (types[v] == VertexType.H_BOX and phases[v] == 1 and ( + (types[v] == VertexType.H_BOX and is_standard_hbox(g, v) and ( (et == EdgeType.SIMPLE and types[w] == VertexType.X) or (et == EdgeType.HADAMARD and types[w] == VertexType.Z))) ): diff --git a/pyzx/rewrite_rules/zero_hbox_rule.py b/pyzx/rewrite_rules/zero_hbox_rule.py index ad52065e..91327d1b 100644 --- a/pyzx/rewrite_rules/zero_hbox_rule.py +++ b/pyzx/rewrite_rules/zero_hbox_rule.py @@ -27,16 +27,19 @@ 'unsafe_zero_hbox'] -from pyzx.utils import VertexType +import cmath +from pyzx.utils import VertexType, get_h_box_label, hbox_has_complex_label from pyzx.graph.base import BaseGraph, ET, VT def check_zero_hbox(g: BaseGraph[VT,ET], v:VT) -> bool: - """Matches H-boxes that have a phase of 2pi==0.""" + """Matches H-boxes with label 1 (or phase 0).""" types = g.types() - phases = g.phases() - if types[v] == VertexType.H_BOX and phases[v] == 0: return True - return False + if types[v] != VertexType.H_BOX: + return False + if hbox_has_complex_label(g, v): + return cmath.isclose(get_h_box_label(g, v), 1) + return g.phase(v) == 0 def zero_hbox(g: BaseGraph[VT,ET], v: VT) -> bool: diff --git a/pyzx/tensor.py b/pyzx/tensor.py index 82d29b6a..aa5c2a67 100644 --- a/pyzx/tensor.py +++ b/pyzx/tensor.py @@ -72,9 +72,12 @@ def X_to_tensor(arity: int, phase: float) -> np.ndarray: m[i] -= np.exp(1j*phase) return np.power(np.sqrt(0.5),arity)*m.reshape([2]*arity) -def H_to_tensor(arity: int, phase: float) -> np.ndarray: +def H_to_tensor(arity: int, phase: float, label: Optional[complex] = None) -> np.ndarray: m = np.ones(2**arity, dtype = complex) - if phase != 0: m[-1] = np.exp(1j*phase) + if label is not None: + m[-1] = label + elif phase != 0: + m[-1] = np.exp(1j*phase) return m.reshape([2]*arity) def W_to_tensor(arity: int) -> np.ndarray: @@ -184,7 +187,12 @@ def tensorfy_naive(g: 'BaseGraph[VT,ET]', preserve_scalar: bool = True) -> NDArr elif types[v] == VertexType.X: t = X_to_tensor(d,phase) elif types[v] == VertexType.H_BOX: - t = H_to_tensor(d,phase) + # Check if H-box has a complex label. + h_label = g.vdata(v, 'label', None) + if h_label is not None: + t = H_to_tensor(d, 0, label=complex(h_label)) + else: + t = H_to_tensor(d, phase) elif types[v] == VertexType.W_INPUT or types[v] == VertexType.W_OUTPUT: if phase != 0: raise ValueError("Phase on W node") t = W_to_tensor(d) diff --git a/pyzx/tikz.py b/pyzx/tikz.py index 7ef30899..e7f5afa4 100644 --- a/pyzx/tikz.py +++ b/pyzx/tikz.py @@ -27,7 +27,9 @@ from fractions import Fraction from typing import List, Dict, overload, Tuple, Union, Optional -from .utils import get_z_box_label, set_z_box_label, settings, EdgeType, VertexType, FloatInt +import cmath +from .utils import (get_z_box_label, set_z_box_label, get_h_box_label, set_h_box_label, + hbox_has_complex_label, settings, EdgeType, VertexType, FloatInt) from .graph.base import BaseGraph, VT, ET from .graph.graph import Graph from .circuit import Circuit @@ -74,6 +76,8 @@ def _to_tikz(g: BaseGraph[VT,ET], draw_scalar:bool = False, continue if ty == VertexType.Z_BOX: p = get_z_box_label(g,v) + elif ty == VertexType.H_BOX and hbox_has_complex_label(g, v): + p = get_h_box_label(g, v) else: p = g.phase(v) if ty == VertexType.BOUNDARY: @@ -93,7 +97,14 @@ def _to_tikz(g: BaseGraph[VT,ET], draw_scalar:bool = False, else: if ty==VertexType.Z: style = settings.tikz_classes['Z'] else: style = settings.tikz_classes['X'] - if ((ty == VertexType.H_BOX or ty == VertexType.Z_BOX) and p == 1) or\ + # Determine whether to display the phase/label. + if ty == VertexType.H_BOX and hbox_has_complex_label(g, v): + # For H-boxes with complex labels, hide if standard (-1). + if cmath.isclose(p, -1): + phase = "" + else: + phase = r"$%s$" % str(p) + elif ((ty == VertexType.H_BOX or ty == VertexType.Z_BOX) and p == 1) or\ (ty != VertexType.H_BOX and p == 0): phase = "" elif type(p) == Fraction: @@ -320,12 +331,29 @@ def handle_phase_error(msg: str) -> None: elif label == r'\neg': set_phase(v,1) elif label: - if label.find('pi') == -1 and ty != VertexType.Z_BOX: - if not ignore_nonzx and not ignore_invalid_phases: - raise ValueError("Node definition %s has invalid phase label" % l) - elif ignore_invalid_phases: - set_phase(v, default_phase) - else: + # Check if label might be a complex number for H-box or Z-box. + is_complex_label = ty in (VertexType.Z_BOX, VertexType.H_BOX) and label.find('pi') == -1 + if is_complex_label: + # Try to parse as complex number. + try: + complex_val = complex(label) + if ty == VertexType.H_BOX: + set_h_box_label(g, v, complex_val) + else: + set_phase(v, complex_val) + continue + except ValueError: + # Not a valid complex, fall through to standard parsing. + if ty == VertexType.Z_BOX: + pass # Z-box will be handled below. + elif not ignore_nonzx or ignore_invalid_phases: + handle_phase_error("Node definition %s has invalid phase label" % l) + continue + elif label.find('pi') == -1: + if not ignore_nonzx or ignore_invalid_phases: + handle_phase_error("Node definition %s has invalid phase label" % l) + continue + if label.find('pi') != -1 or ty == VertexType.Z_BOX: label = label.replace(r'\pi','').strip() if label == '' or label == '-' or label == '-1': set_phase(v,1) diff --git a/pyzx/utils.py b/pyzx/utils.py index cc1150ce..adbc6185 100644 --- a/pyzx/utils.py +++ b/pyzx/utils.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import cmath +import math import os from argparse import ArgumentTypeError from enum import IntEnum @@ -307,5 +309,35 @@ def set_z_box_label(g, v, label): assert g.type(v) == VertexType.Z_BOX g.set_vdata(v, 'label', label) + +def get_h_box_label(g, v) -> complex: + assert g.type(v) == VertexType.H_BOX + label = g.vdata(v, 'label', None) + if label is not None: + return complex(label) + phase = g.phase(v) + if isinstance(phase, Poly): + raise ValueError("Cannot convert symbolic phase to complex label") + return cmath.exp(1j * math.pi * float(phase)) + +def set_h_box_label(g, v, label: complex) -> None: + assert g.type(v) == VertexType.H_BOX + g.set_vdata(v, 'label', complex(label)) + g.set_phase(v, 0) + +def is_standard_hbox(g, v) -> bool: + """Check if H-box has the standard Hadamard label (-1).""" + assert g.type(v) == VertexType.H_BOX + label = g.vdata(v, 'label', None) + if label is not None: + return cmath.isclose(label, -1) + return g.phase(v) == 1 + +def hbox_has_complex_label(g, v) -> bool: + """Check if H-box uses a complex label instead of legacy phase.""" + assert g.type(v) == VertexType.H_BOX + return g.vdata(v, 'label', None) is not None + + # Return position 'perc'%-distance between 2 points: def ave_pos(a,b,perc=1/2): return (abs(a-b))*(perc) + min(a,b) diff --git a/tests/test_hbox_cancel.py b/tests/test_hbox.py similarity index 65% rename from tests/test_hbox_cancel.py rename to tests/test_hbox.py index 3dd6333e..8a7c0845 100644 --- a/tests/test_hbox_cancel.py +++ b/tests/test_hbox.py @@ -26,8 +26,11 @@ from fractions import Fraction from pyzx.graph import Graph -from pyzx.utils import EdgeType, VertexType +from pyzx.utils import (EdgeType, VertexType, get_h_box_label, set_h_box_label, + is_standard_hbox, hbox_has_complex_label) from pyzx.rewrite_rules.hbox_cancel_rule import check_hbox_cancel, hbox_cancel +from pyzx.rewrite_rules.zero_hbox_rule import check_zero_hbox +from pyzx.rewrite_rules.copy_rule import check_copy from pyzx.hsimplify import hbox_cancel_simp np: Optional[ModuleType] @@ -282,5 +285,140 @@ def test_hbox_cancel_simp(self): self.assertGreater(rewrites, 0) self.assertEqual(g.num_vertices(), 2) + # Tests for H-box label helper functions. + + def test_get_h_box_label_with_complex_label(self): + """Test get_h_box_label returns the stored complex label.""" + g = Graph() + v = g.add_vertex(VertexType.H_BOX, 0, 0) + g.set_vdata(v, 'label', 1+2j) + self.assertEqual(get_h_box_label(g, v), 1+2j) + + def test_get_h_box_label_phase_fallback(self): + """Test get_h_box_label converts phase to complex when no label set.""" + g = Graph() + v = g.add_vertex(VertexType.H_BOX, 0, 0) + g.set_phase(v, 0) # exp(0) = 1 + self.assertAlmostEqual(get_h_box_label(g, v), 1, places=10) + g.set_phase(v, Fraction(1, 2)) # exp(i*pi/2) = i + self.assertAlmostEqual(get_h_box_label(g, v), 1j, places=10) + g.set_phase(v, 1) # exp(i*pi) = -1 + self.assertAlmostEqual(get_h_box_label(g, v), -1, places=10) + + def test_set_h_box_label(self): + """Test set_h_box_label stores the complex label.""" + g = Graph() + v = g.add_vertex(VertexType.H_BOX, 0, 0) + set_h_box_label(g, v, 3+4j) + self.assertEqual(g.vdata(v, 'label'), 3+4j) + + def test_is_standard_hbox_with_label(self): + """Test is_standard_hbox with label=-1.""" + g = Graph() + v = g.add_vertex(VertexType.H_BOX, 0, 0) + set_h_box_label(g, v, -1) + self.assertTrue(is_standard_hbox(g, v)) + set_h_box_label(g, v, 2+3j) + self.assertFalse(is_standard_hbox(g, v)) + + def test_is_standard_hbox_with_phase(self): + """Test is_standard_hbox with legacy phase.""" + g = Graph() + v = g.add_vertex(VertexType.H_BOX, 0, 0) + g.set_phase(v, 1) + self.assertTrue(is_standard_hbox(g, v)) + g.set_phase(v, Fraction(1, 2)) + self.assertFalse(is_standard_hbox(g, v)) + + def test_hbox_has_complex_label(self): + """Test hbox_has_complex_label detection.""" + g = Graph() + v = g.add_vertex(VertexType.H_BOX, 0, 0) + g.set_phase(v, 1) + self.assertFalse(hbox_has_complex_label(g, v)) + set_h_box_label(g, v, 1j) + self.assertTrue(hbox_has_complex_label(g, v)) + + # Tests for rewrite rules with H-box labels. + + def test_check_hbox_cancel_with_standard_label(self): + """Test that hbox_cancel works with standard label (-1).""" + g = Graph() + v0 = g.add_vertex(VertexType.BOUNDARY, 0, 0) + v1 = g.add_vertex(VertexType.H_BOX, 0, 1) + set_h_box_label(g, v1, -1) + v2 = g.add_vertex(VertexType.H_BOX, 0, 2) + set_h_box_label(g, v2, -1) + v3 = g.add_vertex(VertexType.BOUNDARY, 0, 3) + g.add_edge((v0, v1)) + g.add_edge((v1, v2)) + g.add_edge((v2, v3)) + + self.assertTrue(check_hbox_cancel(g, v1)) + self.assertTrue(check_hbox_cancel(g, v2)) + + def test_check_hbox_cancel_with_nonstandard_label(self): + """Test that hbox_cancel doesn't apply to non-standard labels.""" + g = Graph() + v0 = g.add_vertex(VertexType.BOUNDARY, 0, 0) + v1 = g.add_vertex(VertexType.H_BOX, 0, 1) + set_h_box_label(g, v1, 1+2j) + v2 = g.add_vertex(VertexType.BOUNDARY, 0, 2) + g.add_edge((v0, v1)) + g.add_edge((v1, v2)) + + self.assertFalse(check_hbox_cancel(g, v1)) + + def test_check_zero_hbox_with_label(self): + """Test that zero_hbox detects label=1 (all-ones tensor).""" + g = Graph() + v = g.add_vertex(VertexType.H_BOX, 0, 0) + set_h_box_label(g, v, 1) + self.assertTrue(check_zero_hbox(g, v)) + set_h_box_label(g, v, -1) + self.assertFalse(check_zero_hbox(g, v)) + + def test_copy_rule_standard_hbox(self): + """Test that copy rule applies to standard H-boxes.""" + g = Graph() + # X spider with phase 0 connected to standard H-box + v = g.add_vertex(VertexType.X, 0, 0) + g.set_phase(v, 0) + h = g.add_vertex(VertexType.H_BOX, 0, 1) + g.set_phase(h, 1) # Standard Hadamard + z = g.add_vertex(VertexType.Z, 0, 2) + g.add_edge((v, h)) + g.add_edge((h, z)) + + self.assertTrue(check_copy(g, v)) + + def test_copy_rule_nonstandard_hbox(self): + """Test that copy rule doesn't apply to non-standard H-boxes.""" + g = Graph() + # X spider with phase 0 connected to non-standard H-box + v = g.add_vertex(VertexType.X, 0, 0) + g.set_phase(v, 0) + h = g.add_vertex(VertexType.H_BOX, 0, 1) + set_h_box_label(g, h, 2+3j) # Non-standard label + z = g.add_vertex(VertexType.Z, 0, 2) + g.add_edge((v, h)) + g.add_edge((h, z)) + + self.assertFalse(check_copy(g, v)) + + def test_copy_rule_standard_hbox_with_label(self): + """Test that copy rule applies to H-boxes with label=-1.""" + g = Graph() + v = g.add_vertex(VertexType.X, 0, 0) + g.set_phase(v, 0) + h = g.add_vertex(VertexType.H_BOX, 0, 1) + set_h_box_label(g, h, -1) # Standard Hadamard via label + z = g.add_vertex(VertexType.Z, 0, 2) + g.add_edge((v, h)) + g.add_edge((h, z)) + + self.assertTrue(check_copy(g, v)) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_jsonparser.py b/tests/test_jsonparser.py index bdf33680..6b5c862c 100644 --- a/tests/test_jsonparser.py +++ b/tests/test_jsonparser.py @@ -26,7 +26,7 @@ from pyzx.graph import Graph from pyzx.graph.scalar import Scalar from pyzx.graph.jsonparser import graph_to_dict, dict_to_graph -from pyzx.utils import EdgeType, VertexType +from pyzx.utils import EdgeType, VertexType, set_h_box_label, get_h_box_label, hbox_has_complex_label from pyzx.symbolic import Poly, new_var @@ -381,5 +381,65 @@ def test_zbox_label_roundtrip(self): self.assertEqual(g3.vdata(v1, 'name'), 'my vertex') self.assertEqual(g3.vdata(v2, 'custom'), 42) + def test_hbox_label_roundtrip(self): + """Test JSON round-trip for H-box complex labels.""" + g = Graph() + v1 = g.add_vertex(VertexType.H_BOX, 0, 0) + v2 = g.add_vertex(VertexType.H_BOX, 0, 1) + v3 = g.add_vertex(VertexType.H_BOX, 0, 2) + + set_h_box_label(g, v1, -1) # Standard Hadamard + set_h_box_label(g, v2, 1j) # Complex label + set_h_box_label(g, v3, 2.5 + 1.3j) # Another complex label + + d = graph_to_dict(g) + g2 = dict_to_graph(d) + + self.assertTrue(hbox_has_complex_label(g2, v1)) + self.assertTrue(hbox_has_complex_label(g2, v2)) + self.assertTrue(hbox_has_complex_label(g2, v3)) + self.assertEqual(get_h_box_label(g2, v1), -1) + self.assertEqual(get_h_box_label(g2, v2), 1j) + self.assertEqual(get_h_box_label(g2, v3), 2.5 + 1.3j) + + js = json.dumps(d) + d2 = json.loads(js) + g3 = dict_to_graph(d2) + + self.assertTrue(hbox_has_complex_label(g3, v1)) + self.assertTrue(hbox_has_complex_label(g3, v2)) + self.assertTrue(hbox_has_complex_label(g3, v3)) + self.assertEqual(get_h_box_label(g3, v1), -1) + self.assertEqual(get_h_box_label(g3, v2), 1j) + self.assertEqual(get_h_box_label(g3, v3), 2.5 + 1.3j) + + def test_hbox_label_tikz_roundtrip(self): + """Test tikz round-trip for H-box complex labels.""" + g = Graph() + v1 = g.add_vertex(VertexType.H_BOX, 0, 0) + v2 = g.add_vertex(VertexType.H_BOX, 0, 1) + v3 = g.add_vertex(VertexType.H_BOX, 0, 2) + + set_h_box_label(g, v1, -1) # Standard Hadamard + set_h_box_label(g, v2, 1j) # Complex label + set_h_box_label(g, v3, 2.5+1.3j) # Another complex label + + tikz = g.to_tikz() + g2 = Graph.from_tikz(tikz, warn_overlap=False) + + # Find corresponding vertices in g2 by position. + v1_new = [v for v in g2.vertices() if g2.row(v) == 0][0] + v2_new = [v for v in g2.vertices() if g2.row(v) == 1][0] + v3_new = [v for v in g2.vertices() if g2.row(v) == 2][0] + + # Standard Hadamard (-1) exports as empty and imports as legacy phase=1. + # Check semantic equivalence rather than format preservation. + self.assertAlmostEqual(get_h_box_label(g2, v1_new), -1, places=10) + # Non-standard labels should preserve exact format. + self.assertTrue(hbox_has_complex_label(g2, v2_new)) + self.assertTrue(hbox_has_complex_label(g2, v3_new)) + self.assertEqual(get_h_box_label(g2, v2_new), 1j) + self.assertEqual(get_h_box_label(g2, v3_new), 2.5+1.3j) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_tensor.py b/tests/test_tensor.py index e4663158..552cbbcd 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -24,15 +24,18 @@ if __name__ == '__main__': sys.path.append('..') sys.path.append('.') +import math + from pyzx.graph import Graph from pyzx.graph.multigraph import Multigraph from pyzx.generate import cliffords from pyzx.circuit import Circuit +from pyzx.utils import VertexType, set_h_box_label np: Optional[ModuleType] try: import numpy as np - from pyzx.tensor import tensorfy, compare_tensors, compose_tensors, adjoint, VertexType + from pyzx.tensor import tensorfy, compare_tensors, compose_tensors, adjoint, H_to_tensor except ImportError: np = None @@ -189,6 +192,76 @@ def test_to_tensor_equivalent(self): g1.add_vertex(VertexType.X, phase=1) self.assertTrue(g.to_tensor() == g1.to_tensor()) + def test_h_to_tensor_with_label(self): + """Test H_to_tensor with explicit complex label.""" + t = H_to_tensor(2, 0, label=3+4j) + expected = np.array([[1, 1], [1, 3+4j]]) + self.assertTrue(np.allclose(t, expected)) + + def test_h_to_tensor_standard_hadamard(self): + """Test H_to_tensor for standard Hadamard (label=-1 or phase=pi).""" + t_label = H_to_tensor(2, 0, label=-1) + t_phase = H_to_tensor(2, math.pi) + expected = np.array([[1, 1], [1, -1]]) + self.assertTrue(np.allclose(t_label, expected)) + self.assertTrue(np.allclose(t_phase, expected)) + + def test_tensorfy_hbox_with_complex_label(self): + """Test tensorfy with H-box having complex label.""" + g = Graph() + i = g.add_vertex(VertexType.BOUNDARY, 0, 0) + h = g.add_vertex(VertexType.H_BOX, 0, 1) + o = g.add_vertex(VertexType.BOUNDARY, 0, 2) + g.set_inputs((i,)) + g.set_outputs((o,)) + g.add_edge((i, h)) + g.add_edge((h, o)) + set_h_box_label(g, h, 1j) + + t = tensorfy(g) + expected = np.array([[1, 1], [1, 1j]]) + self.assertTrue(np.allclose(t, expected)) + + def test_tensorfy_hbox_with_standard_label(self): + """Test tensorfy with H-box having standard label -1.""" + g = Graph() + i = g.add_vertex(VertexType.BOUNDARY, 0, 0) + h = g.add_vertex(VertexType.H_BOX, 0, 1) + o = g.add_vertex(VertexType.BOUNDARY, 0, 2) + g.set_inputs((i,)) + g.set_outputs((o,)) + g.add_edge((i, h)) + g.add_edge((h, o)) + set_h_box_label(g, h, -1) + + t = tensorfy(g) + expected = np.array([[1, 1], [1, -1]]) + self.assertTrue(np.allclose(t, expected)) + + def test_tensorfy_hbox_phase_and_label_equivalence(self): + """Test that phase=1 and label=-1 produce same tensor.""" + g1 = Graph() + i1 = g1.add_vertex(VertexType.BOUNDARY, 0, 0) + h1 = g1.add_vertex(VertexType.H_BOX, 0, 1) + o1 = g1.add_vertex(VertexType.BOUNDARY, 0, 2) + g1.set_inputs((i1,)) + g1.set_outputs((o1,)) + g1.add_edge((i1, h1)) + g1.add_edge((h1, o1)) + g1.set_phase(h1, 1) + + g2 = Graph() + i2 = g2.add_vertex(VertexType.BOUNDARY, 0, 0) + h2 = g2.add_vertex(VertexType.H_BOX, 0, 1) + o2 = g2.add_vertex(VertexType.BOUNDARY, 0, 2) + g2.set_inputs((i2,)) + g2.set_outputs((o2,)) + g2.add_edge((i2, h2)) + g2.add_edge((h2, o2)) + set_h_box_label(g2, h2, -1) + + self.assertTrue(compare_tensors(g1, g2, preserve_scalar=True)) + if __name__ == '__main__': unittest.main()