From e0099666f01081ec1093c0b4a41569dd41d54d1c Mon Sep 17 00:00:00 2001 From: Aleks Kissinger Date: Sun, 22 Dec 2024 14:31:46 +0000 Subject: [PATCH] better type annotations in BaseGraph --- pyzx/graph/base.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index 39782671..811efb2a 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import abc import math import copy @@ -124,7 +125,7 @@ def stats(self) -> str: s += "{:d}: {:d}\n".format(d,n) return s - def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph': + def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> BaseGraph[VT,ET]: """Create a copy of the graph. If ``adjoint`` is set, the adjoint of the graph will be returned (inputs and outputs flipped, phases reversed). When ``backend`` is set, a copy of the graph with the given backend is produced. @@ -151,7 +152,7 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph': # mypy issue https://github.com/python/mypy/issues/16413 g.track_phases = self.track_phases g.scalar = self.scalar.copy(conjugate=adjoint) - g.merge_vdata = self.merge_vdata + g.merge_vdata = self.merge_vdata # type: ignore mult:int = 1 if adjoint: mult = -1 @@ -190,11 +191,11 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph': return g - def adjoint(self) -> 'BaseGraph': + def adjoint(self) -> BaseGraph[VT,ET]: """Returns a new graph equal to the adjoint of this graph.""" return self.copy(adjoint=True) - def clone(self) -> 'BaseGraph': + def clone(self) -> BaseGraph[VT,ET]: """ This method should return an identical copy of the graph, without any relabeling. @@ -215,7 +216,7 @@ def map_qubits(self, qubit_map:Mapping[int,Tuple[float,float]]) -> None: self.set_row(v, rf) - # def replace_subgraph(self, left_row: FloatInt, right_row: FloatInt, replace: 'BaseGraph') -> None: + # def replace_subgraph(self, left_row: FloatInt, right_row: FloatInt, replace: BaseGraph[VT,ET]) -> None: # """Deletes the subgraph of all nodes with rank strictly between ``left_row`` # and ``right_row`` and replaces it with the graph ``replace``. # The amount of nodes on the left row should match the amount of inputs of @@ -260,7 +261,7 @@ def map_qubits(self, qubit_map:Mapping[int,Tuple[float,float]]) -> None: # for e,f in etab.items(): # self.set_edge_type(f, replace.edge_type(e)) - def compose(self, other: 'BaseGraph') -> None: + def compose(self, other: BaseGraph[VT,ET]) -> None: """Inserts a graph after this one. The amount of qubits of the graphs must match. Also available by the operator `graph1 + graph2`""" other = other.copy() @@ -314,7 +315,7 @@ def compose(self, other: 'BaseGraph') -> None: - def tensor(self, other: 'BaseGraph') -> 'BaseGraph': + def tensor(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]: """Take the tensor product of two graphs. Places the second graph below the first one. Can also be called using the operator ``graph1 @ graph2``""" g = self.copy() @@ -349,25 +350,25 @@ def tensor(self, other: 'BaseGraph') -> 'BaseGraph': return g - def __iadd__(self, other: 'BaseGraph') -> 'BaseGraph': + def __iadd__(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]: self.compose(other) return self - def __add__(self, other: 'BaseGraph') -> 'BaseGraph': + def __add__(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]: g = self.copy() g += other return g - def __mul__(self, other: 'BaseGraph') -> 'BaseGraph': + def __mul__(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]: """Compose two diagrams, in formula order. That is, g * h produces 'g AFTER h'.""" g = other.copy() g.compose(self) return g - def __matmul__(self, other: 'BaseGraph') -> 'BaseGraph': + def __matmul__(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]: return self.tensor(other) - def merge(self, other: 'BaseGraph') -> Tuple[List[VT],List[ET]]: + def merge(self, other: BaseGraph[VT,ET]) -> Tuple[List[VT],List[ET]]: """Merges this graph with the other graph in-place. Returns (list-of-vertices, list-of-edges) corresponding to the id's of the vertices and edges of the other graph.""" @@ -389,7 +390,7 @@ def merge(self, other: 'BaseGraph') -> Tuple[List[VT],List[ET]]: edges.append(e) return (list(vert_map.values()),edges) - def subgraph_from_vertices(self,verts: List[VT]) -> 'BaseGraph': + def subgraph_from_vertices(self,verts: List[VT]) -> BaseGraph[VT,ET]: """Returns the subgraph consisting of the specified vertices.""" from .graph import Graph # imported here to prevent circularity from .multigraph import Multigraph @@ -501,14 +502,14 @@ def to_tikz(self,draw_scalar:bool=False) -> str: return to_tikz(self,draw_scalar) @classmethod - def from_json(cls, js:Union[str,Dict[str,Any]]) -> 'BaseGraph': + def from_json(cls, js:Union[str,Dict[str,Any]]) -> BaseGraph[VT,ET]: """Converts the given .qgraph json string into a Graph. Works with the output of :meth:`to_json`.""" from .jsonparser import json_to_graph return json_to_graph(js,cls.backend) @classmethod - def from_tikz(cls, tikz: str, warn_overlap:bool= True, fuse_overlap:bool = True, ignore_nonzx:bool = False) -> 'BaseGraph': + def from_tikz(cls, tikz: str, warn_overlap:bool= True, fuse_overlap:bool = True, ignore_nonzx:bool = False) -> BaseGraph[VT,ET]: """Converts a tikz diagram into a pyzx Graph. The tikz diagram is assumed to be one generated by Tikzit, and hence should have a nodelayer and a edgelayer.. @@ -637,7 +638,7 @@ def normalize(self) -> None: self.pack_circuit_rows() - def translate(self, x:FloatInt, y:FloatInt) -> 'BaseGraph': + def translate(self, x:FloatInt, y:FloatInt) -> BaseGraph[VT,ET]: g = self.copy() for v in g.vertices(): g.set_row(v, g.row(v)+x)