Skip to content

Commit

Permalink
Merge pull request #241 from RazinShaikh/multigraph-counter-diff
Browse files Browse the repository at this point in the history
Fix several multigraph bugs including counter and graphdiff
  • Loading branch information
jvdwetering authored Jun 27, 2024
2 parents b2bb047 + 32f7ef7 commit cd0269d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
11 changes: 10 additions & 1 deletion pyzx/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,13 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph':
graph did not.
"""
from .graph import Graph # imported here to prevent circularity
from .multigraph import Multigraph
if (backend is None):
backend = type(self).backend
g = Graph(backend = backend)
if isinstance(self, Multigraph) and isinstance(g, Multigraph):
g.set_auto_simplify(self._auto_simplify) # type: ignore
# mypy issue https://github.com/python/mypy/issues/16413
g.track_phases = self.track_phases
g.scalar = self.scalar.copy(conjugate=adjoint)
g.merge_vdata = self.merge_vdata
Expand Down Expand Up @@ -390,14 +394,19 @@ def merge(self, other: 'BaseGraph') -> Tuple[List[VT],List[ET]]:
def subgraph_from_vertices(self,verts: List[VT]) -> 'BaseGraph':
"""Returns the subgraph consisting of the specified vertices."""
from .graph import Graph # imported here to prevent circularity
from .multigraph import Multigraph
g = Graph(backend=type(self).backend)
if isinstance(self, Multigraph) and isinstance(g, Multigraph):
g.set_auto_simplify(self._auto_simplify) # type: ignore
# mypy issue https://github.com/python/mypy/issues/16413
ty = self.types()
rs = self.rows()
qs = self.qubits()
phase = self.phases()
grounds = self.grounds()

edges = [self.edge(v,w) for v in verts for w in verts if self.connected(v,w)]
edges = [e for e in self.edges() \
if self.edge_st(e)[0] in verts and self.edge_st(e)[1] in verts]

vert_map = dict()
for v in verts:
Expand Down
10 changes: 5 additions & 5 deletions pyzx/graph/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
from collections import Counter
from typing import Any, Callable, Generic, Optional, List, Dict, Tuple
import copy

from ..utils import VertexType, EdgeType, FractionLike, FloatInt, phase_to_s
from .base import BaseGraph, VT, ET
Expand Down Expand Up @@ -56,12 +57,11 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None:
self.new_edges = []
self.removed_edges = []

for e in (new_edges - old_edges):
for e in Counter(new_edges - old_edges).elements():
self.new_edges.append((g2.edge_st(e), g2.edge_type(e)))

for e in (old_edges - new_edges):
for e in Counter(old_edges - new_edges).elements():
s,t = g1.edge_st(e)
if s in self.removed_verts or t in self.removed_verts: continue
self.removed_edges.append(e)

for v in new_verts:
Expand Down Expand Up @@ -94,8 +94,8 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None:

def apply_diff(self,g: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
g = copy.deepcopy(g)
g.remove_vertices(self.removed_verts)
g.remove_edges(self.removed_edges)
g.remove_vertices(self.removed_verts)
for v in self.new_verts:
g.add_vertex_indexed(v)
g.set_position(v,*self.changed_pos[v])
Expand Down
4 changes: 2 additions & 2 deletions pyzx/graph/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ def clone(self) -> 'Multigraph':
cpy.phase_mult = self.phase_mult.copy()
cpy.max_phase_index = self.max_phase_index
return cpy

def set_auto_simplify(self, s: bool):
"""Automatically remove parallel edges as edges are added"""
self._auto_simplify = s

def multigraph(self):
return False

Expand Down

0 comments on commit cd0269d

Please sign in to comment.