Skip to content

Commit

Permalink
better type annotations in BaseGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
akissinger committed Dec 22, 2024
1 parent 4dac7b4 commit e009966
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions pyzx/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import abc
import math
import copy
Expand Down Expand Up @@ -124,7 +125,7 @@ def stats(self) -> str:
s += "{:d}: {:d}\n".format(d,n)
return s

def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph':
def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> BaseGraph[VT,ET]:
"""Create a copy of the graph. If ``adjoint`` is set,
the adjoint of the graph will be returned (inputs and outputs flipped, phases reversed).
When ``backend`` is set, a copy of the graph with the given backend is produced.
Expand All @@ -151,7 +152,7 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph':
# mypy issue https://github.com/python/mypy/issues/16413
g.track_phases = self.track_phases
g.scalar = self.scalar.copy(conjugate=adjoint)
g.merge_vdata = self.merge_vdata
g.merge_vdata = self.merge_vdata # type: ignore
mult:int = 1
if adjoint:
mult = -1
Expand Down Expand Up @@ -190,11 +191,11 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph':

return g

def adjoint(self) -> 'BaseGraph':
def adjoint(self) -> BaseGraph[VT,ET]:
"""Returns a new graph equal to the adjoint of this graph."""
return self.copy(adjoint=True)

def clone(self) -> 'BaseGraph':
def clone(self) -> BaseGraph[VT,ET]:
"""
This method should return an identical copy of the graph, without any relabeling.
Expand All @@ -215,7 +216,7 @@ def map_qubits(self, qubit_map:Mapping[int,Tuple[float,float]]) -> None:
self.set_row(v, rf)


# def replace_subgraph(self, left_row: FloatInt, right_row: FloatInt, replace: 'BaseGraph') -> None:
# def replace_subgraph(self, left_row: FloatInt, right_row: FloatInt, replace: BaseGraph[VT,ET]) -> None:
# """Deletes the subgraph of all nodes with rank strictly between ``left_row``
# and ``right_row`` and replaces it with the graph ``replace``.
# The amount of nodes on the left row should match the amount of inputs of
Expand Down Expand Up @@ -260,7 +261,7 @@ def map_qubits(self, qubit_map:Mapping[int,Tuple[float,float]]) -> None:
# for e,f in etab.items():
# self.set_edge_type(f, replace.edge_type(e))

def compose(self, other: 'BaseGraph') -> None:
def compose(self, other: BaseGraph[VT,ET]) -> None:
"""Inserts a graph after this one. The amount of qubits of the graphs must match.
Also available by the operator `graph1 + graph2`"""
other = other.copy()
Expand Down Expand Up @@ -314,7 +315,7 @@ def compose(self, other: 'BaseGraph') -> None:



def tensor(self, other: 'BaseGraph') -> 'BaseGraph':
def tensor(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
"""Take the tensor product of two graphs. Places the second graph below the first one.
Can also be called using the operator ``graph1 @ graph2``"""
g = self.copy()
Expand Down Expand Up @@ -349,25 +350,25 @@ def tensor(self, other: 'BaseGraph') -> 'BaseGraph':

return g

def __iadd__(self, other: 'BaseGraph') -> 'BaseGraph':
def __iadd__(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
self.compose(other)
return self

def __add__(self, other: 'BaseGraph') -> 'BaseGraph':
def __add__(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
g = self.copy()
g += other
return g

def __mul__(self, other: 'BaseGraph') -> 'BaseGraph':
def __mul__(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
"""Compose two diagrams, in formula order. That is, g * h produces 'g AFTER h'."""
g = other.copy()
g.compose(self)
return g

def __matmul__(self, other: 'BaseGraph') -> 'BaseGraph':
def __matmul__(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
return self.tensor(other)

def merge(self, other: 'BaseGraph') -> Tuple[List[VT],List[ET]]:
def merge(self, other: BaseGraph[VT,ET]) -> Tuple[List[VT],List[ET]]:
"""Merges this graph with the other graph in-place.
Returns (list-of-vertices, list-of-edges) corresponding to
the id's of the vertices and edges of the other graph."""
Expand All @@ -389,7 +390,7 @@ def merge(self, other: 'BaseGraph') -> Tuple[List[VT],List[ET]]:
edges.append(e)
return (list(vert_map.values()),edges)

def subgraph_from_vertices(self,verts: List[VT]) -> 'BaseGraph':
def subgraph_from_vertices(self,verts: List[VT]) -> BaseGraph[VT,ET]:
"""Returns the subgraph consisting of the specified vertices."""
from .graph import Graph # imported here to prevent circularity
from .multigraph import Multigraph
Expand Down Expand Up @@ -501,14 +502,14 @@ def to_tikz(self,draw_scalar:bool=False) -> str:
return to_tikz(self,draw_scalar)

@classmethod
def from_json(cls, js:Union[str,Dict[str,Any]]) -> 'BaseGraph':
def from_json(cls, js:Union[str,Dict[str,Any]]) -> BaseGraph[VT,ET]:
"""Converts the given .qgraph json string into a Graph.
Works with the output of :meth:`to_json`."""
from .jsonparser import json_to_graph
return json_to_graph(js,cls.backend)

@classmethod
def from_tikz(cls, tikz: str, warn_overlap:bool= True, fuse_overlap:bool = True, ignore_nonzx:bool = False) -> 'BaseGraph':
def from_tikz(cls, tikz: str, warn_overlap:bool= True, fuse_overlap:bool = True, ignore_nonzx:bool = False) -> BaseGraph[VT,ET]:
"""Converts a tikz diagram into a pyzx Graph.
The tikz diagram is assumed to be one generated by Tikzit,
and hence should have a nodelayer and a edgelayer..
Expand Down Expand Up @@ -637,7 +638,7 @@ def normalize(self) -> None:

self.pack_circuit_rows()

def translate(self, x:FloatInt, y:FloatInt) -> 'BaseGraph':
def translate(self, x:FloatInt, y:FloatInt) -> BaseGraph[VT,ET]:
g = self.copy()
for v in g.vertices():
g.set_row(v, g.row(v)+x)
Expand Down

0 comments on commit e009966

Please sign in to comment.