Skip to content

Commit

Permalink
removed _vdata from BaseGraph. this should be handled by backend
Browse files Browse the repository at this point in the history
  • Loading branch information
akissinger committed Dec 31, 2024
1 parent a811efe commit 09a49a4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 55 deletions.
64 changes: 14 additions & 50 deletions pyzx/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def __init__(self) -> None:
self.phase_master: Optional['simplify.Simplifier'] = None
self.phase_mult: Dict[int,Literal[1,-1]] = dict()
self.max_phase_index: int = -1
self._vdata: Dict[VT,Dict[str,Any]] = dict()

# merge_vdata(v0,v1) is an optional, custom function for merging
# vdata of v1 into v0 during spider fusion etc.
Expand Down Expand Up @@ -215,52 +214,6 @@ def map_qubits(self, qubit_map:Mapping[int,Tuple[float,float]]) -> None:
self.set_qubit(v, qf)
self.set_row(v, rf)


# def replace_subgraph(self, left_row: FloatInt, right_row: FloatInt, replace: BaseGraph[VT,ET]) -> None:
# """Deletes the subgraph of all nodes with rank strictly between ``left_row``
# and ``right_row`` and replaces it with the graph ``replace``.
# The amount of nodes on the left row should match the amount of inputs of
# the replacement graph and the same for the right row and the outputs.
# The graphs are glued together based on the qubit index of the vertices."""

# qleft = [v for v in self.vertices() if self.row(v)==left_row]
# qright= [v for v in self.vertices() if self.row(v)==right_row]
# r_inputs = replace.inputs()
# r_outputs = replace.outputs()
# if len(qleft) != len(r_inputs):
# raise TypeError("Inputs do not match glueing vertices")
# if len(qright) != len(r_outputs):
# raise TypeError("Outputs do not match glueing vertices")
# if set(self.qubit(v) for v in qleft) != set(replace.qubit(v) for v in r_inputs):
# raise TypeError("Input qubit indices do not match")
# if set(self.qubit(v) for v in qright)!= set(replace.qubit(v) for v in r_outputs):
# raise TypeError("Output qubit indices do not match")

# self.remove_vertices([v for v in self.vertices() if (left_row < self.row(v) and self.row(v) < right_row)])
# self.remove_edges([self.edge(s,t) for s in qleft for t in qright if self.connected(s,t)])
# rdepth = replace.depth() -1
# for v in (v for v in self.vertices() if self.row(v)>=right_row):
# self.set_row(v, self.row(v)+rdepth)

# vtab = {}
# for v in replace.vertices():
# if v in r_inputs or v in r_outputs: continue
# vtab[v] = self.add_vertex(replace.type(v),
# replace.qubit(v),
# replace.row(v)+left_row,
# replace.phase(v),
# replace.is_ground(v))
# for v in r_inputs:
# vtab[v] = [i for i in qleft if self.qubit(i) == replace.qubit(v)][0]

# for v in r_outputs:
# vtab[v] = [i for i in qright if self.qubit(i) == replace.qubit(v)][0]

# etab = {e:self.edge(vtab[replace.edge_s(e)],vtab[replace.edge_t(e)]) for e in replace.edges()}
# self.add_edges(etab.values())
# for e,f in etab.items():
# self.set_edge_type(f, replace.edge_type(e))

def compose(self, other: BaseGraph[VT,ET]) -> None:
"""Inserts a graph after this one. The amount of qubits of the graphs must match.
Also available by the operator `graph1 + graph2`"""
Expand Down Expand Up @@ -302,7 +255,7 @@ def compose(self, other: BaseGraph[VT,ET]) -> None:
qubit=other.qubit(v),
row=offset + other.row(v),
ground=other.is_ground(v))
if v in other._vdata: self._vdata[w] = other._vdata[v]
self.set_vdata_dict(w, other.vdata_dict(v))
vtab[v] = w
for e in other.edges():
s,t = other.edge_st(e)
Expand All @@ -325,11 +278,10 @@ def tensor(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
height = max((self.qubits().values()), default=0) + 1
rs = other.rows()
phases = other.phases()
vdata = other._vdata
vertex_map = dict()
for v in other.vertices():
w = g.add_vertex(ts[v],qs[v]+height,rs[v],phases[v],g.is_ground(v))
if v in vdata: g._vdata[w] = vdata[v]
g.set_vdata_dict(w, other.vdata_dict(v))
vertex_map[v] = w
for e in other.edges():
s,t = other.edge_st(e)
Expand Down Expand Up @@ -964,6 +916,10 @@ def set_position(self, vertex: VT, q: FloatInt, r: FloatInt):
self.set_qubit(vertex, q)
self.set_row(vertex, r)

def clear_vdata(self, vertex: VT) -> None:
"""Removes all vdata associated to a vertex"""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def vdata_keys(self, vertex: VT) -> Sequence[str]:
"""Returns an iterable of the vertex data key names.
Used e.g. in making a copy of the graph in a backend-independent way."""
Expand All @@ -978,6 +934,14 @@ def set_vdata(self, vertex: VT, key: str, val: Any) -> None:
"""Sets the vertex data associated to key to val."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def vdata_dict(self, vertex: VT) -> Dict[str, Any]:
return { key: self.vdata(vertex, key) for key in self.vdata_keys(vertex) }

def set_vdata_dict(self, vertex: VT, d: Dict[str, Any]) -> None:
self.clear_vdata(vertex)
for k, v in d.items():
self.set_vdata(vertex, k, v)

def is_well_formed(self) -> bool:
"""Returns whether the graph is a well-formed ZX-diagram.
This means that it has no isolated boundary vertices,
Expand Down
11 changes: 6 additions & 5 deletions pyzx/graph/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None:
self.new_edges.append((g2.edge_st(e), g2.edge_type(e)))

for e in Counter(old_edges - new_edges).elements():
s,t = g1.edge_st(e)
self.removed_edges.append(e)

for v in new_verts:
Expand All @@ -70,8 +69,10 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None:
self.changed_vertex_types[v] = g2.type(v)
if g1.phase(v) != g2.phase(v):
self.changed_phases[v] = g2.phase(v)
if g1._vdata.get(v, None) != g2._vdata.get(v, None):
self.changed_vdata[v] = g2._vdata.get(v, None)
d1 = g1.vdata_dict(v)
d2 = g2.vdata_dict(v)
if d1 != d2:
self.changed_vdata[v] = d2
pos1 = g1.qubit(v), g1.row(v)
pos2 = g2.qubit(v), g2.row(v)
if pos1 != pos2:
Expand Down Expand Up @@ -106,7 +107,7 @@ def apply_diff(self,g: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
if v in self.changed_phases:
g.set_phase(v,self.changed_phases[v])
if v in self.changed_vdata:
g._vdata[v] = self.changed_vdata[v]
g.set_vdata_dict(v, self.changed_vdata[v])
for st, ty in self.new_edges:
g.add_edge(st,ty)

Expand All @@ -124,7 +125,7 @@ def apply_diff(self,g: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:

for v in self.changed_vdata:
if v in self.new_verts: continue
g._vdata[v] = self.changed_vdata[v]
g.set_vdata_dict(v, self.changed_vdata[v])

for e in self.changed_edge_types:
g.set_edge_type(e,self.changed_edge_types[e])
Expand Down
3 changes: 3 additions & 0 deletions pyzx/graph/graph_s.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,9 @@ def set_ground(self, vertex, flag=True):
else:
self._grounds.discard(vertex)

def clear_vdata(self, vertex):
if vertex in self._vdata:
del self._vdata[vertex]
def vdata_keys(self, vertex):
return self._vdata.get(vertex, {}).keys()
def vdata(self, vertex, key, default=0):
Expand Down
3 changes: 3 additions & 0 deletions pyzx/graph/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,9 @@ def set_ground(self, vertex, flag=True):
else:
self._grounds.discard(vertex)

def clear_vdata(self, vertex):
if vertex in self._vdata:
del self._vdata[vertex]
def vdata_keys(self, vertex):
return self._vdata.get(vertex, {}).keys()
def vdata(self, vertex, key, default=0):
Expand Down

0 comments on commit 09a49a4

Please sign in to comment.