Skip to content

Commit

Permalink
provided reasonable derived implementations for many methods in BaseG…
Browse files Browse the repository at this point in the history
…raph
  • Loading branch information
akissinger committed Dec 31, 2024
1 parent cf0540f commit 59cccde
Showing 1 changed file with 98 additions and 85 deletions.
183 changes: 98 additions & 85 deletions pyzx/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from fractions import Fraction
from typing import TYPE_CHECKING, Union, Optional, Generic, TypeVar, Any, Sequence
from typing import List, Dict, Set, Tuple, Mapping, Iterable, Callable, ClassVar
from typing_extensions import Literal, GenericMeta # type: ignore # https://github.com/python/mypy/issues/5753
from typing_extensions import Literal, GenericMeta

import numpy as np

Expand Down Expand Up @@ -101,7 +101,9 @@ def __init__(self) -> None:
self.merge_vdata: Optional[Callable[[VT,VT], None]] = None
self.variable_types: Dict[str,bool] = dict() # mapping of variable names to their type (bool or continuous)

# MANDATORY OVERRIDES (ALL BACKENDS) {{{
# MANDATORY OVERRIDES {{{

# All backends should override these methods

def clone(self) -> BaseGraph[VT,ET]:
"""
Expand All @@ -111,11 +113,6 @@ def clone(self) -> BaseGraph[VT,ET]:
"""
raise NotImplementedError()

def vindex(self) -> VT:
"""The index given to the next vertex added to the graph. It should always
be equal to ``max(g.vertices()) + 1``."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def inputs(self) -> Tuple[VT, ...]:
"""Gets the inputs of the graph."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)
Expand All @@ -124,10 +121,6 @@ def set_inputs(self, inputs: Tuple[VT, ...]):
"""Sets the inputs of the graph."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def num_inputs(self) -> int:
"""Gets the number of inputs of the graph."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def outputs(self) -> Tuple[VT, ...]:
"""Gets the outputs of the graph."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)
Expand All @@ -136,13 +129,9 @@ def set_outputs(self, outputs: Tuple[VT, ...]):
"""Sets the outputs of the graph."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def num_outputs(self) -> int:
"""Gets the number of outputs of the graph."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def add_vertices(self, amount: int) -> List[VT]:
"""Add the given amount of vertices, and return the indices of the
new vertices added to the graph, namely: range(g.vindex() - amount, g.vindex())"""
new vertices added to the graph."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def add_vertex_indexed(self,v:VT) -> None:
Expand All @@ -152,10 +141,6 @@ def add_vertex_indexed(self,v:VT) -> None:
which requires vertices to preserve their index."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def add_edges(self, edge_pairs: Iterable[Tuple[VT,VT]], edgetype:EdgeType=EdgeType.SIMPLE) -> None:
"""Adds a list of edges to the graph."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def add_edge(self, edge_pair: Tuple[VT,VT], edgetype:EdgeType=EdgeType.SIMPLE) -> ET:
"""Adds a single edge of the given type and return its id"""
raise NotImplementedError("Not implemented on backend " + type(self).backend)
Expand Down Expand Up @@ -185,48 +170,14 @@ def edges(self, s: Optional[VT]=None, t: Optional[VT]=None) -> Iterable[ET]:
Output type depends on implementation in backend."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def vertex_set(self) -> Set[VT]:
"""Returns the vertices of the graph as a Python set.
Should be overloaded if the backend supplies a cheaper version than this."""
return set(self.vertices())

def edge_set(self) -> Set[ET]:
"""Returns the edges of the graph as a Python set.
Should be overloaded if the backend supplies a cheaper version than this. Note this ignores parallel edges."""
return set(self.edges())

def edge(self, s:VT, t:VT, et: EdgeType=EdgeType.SIMPLE) -> ET:
"""Returns the name of the first edge with the given source/target and type. Behaviour is undefined if the vertices are not connected."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def edge_st(self, edge: ET) -> Tuple[VT, VT]:
"""Returns a tuple of source/target of the given edge."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def edge_s(self, edge: ET) -> VT:
"""Returns the source of the given edge."""
return self.edge_st(edge)[0]

def edge_t(self, edge: ET) -> VT:
"""Returns the target of the given edge."""
return self.edge_st(edge)[1]

def neighbors(self, vertex: VT) -> Sequence[VT]:
"""Returns all neighboring vertices of the given vertex."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def vertex_degree(self, vertex: VT) -> int:
"""Returns the degree of the given vertex."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def incident_edges(self, vertex: VT) -> Sequence[ET]:
"""Returns all neighboring edges of the given vertex."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def connected(self,v1: VT,v2: VT) -> bool:
"""Returns whether vertices v1 and v2 share an edge."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def edge_type(self, e: ET) -> EdgeType:
"""Returns the type of the given edge:
``EdgeType.SIMPLE`` if it is regular, ``EdgeType.HADAMARD`` if it is a Hadamard edge,
Expand All @@ -243,10 +194,6 @@ def type(self, vertex: VT) -> VertexType:
VertexType.X if it is a X node, VertexType.H_BOX if it is an H-box."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def types(self) -> Mapping[VT, VertexType]:
"""Returns a mapping of vertices to their types."""
raise NotImplementedError("Not implemented on backend " + type(self).backend)

def set_type(self, vertex: VT, t: VertexType) -> None:
"""Sets the type of the given vertex to t."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)
Expand All @@ -255,10 +202,6 @@ def phase(self, vertex: VT) -> FractionLike:
"""Returns the phase value of the given vertex."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def phases(self) -> Mapping[VT, FractionLike]:
"""Returns a mapping of vertices to their phase values."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def set_phase(self, vertex: VT, phase: FractionLike) -> None:
"""Sets the phase of the vertex to the given value."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)
Expand All @@ -268,10 +211,6 @@ def qubit(self, vertex: VT) -> FloatInt:
If no index has been set, returns -1."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def qubits(self) -> Mapping[VT,FloatInt]:
"""Returns a mapping of vertices to their qubit index."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def set_qubit(self, vertex: VT, q: FloatInt) -> None:
"""Sets the qubit index associated to the vertex."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)
Expand All @@ -281,23 +220,10 @@ def row(self, vertex: VT) -> FloatInt:
If no row has been set, returns -1."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def rows(self) -> Mapping[VT, FloatInt]:
"""Returns a mapping of vertices to their row index."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def set_row(self, vertex: VT, r: FloatInt) -> None:
"""Sets the row the vertex should be positioned at."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def is_ground(self, vertex: VT) -> bool:
"""Returns a boolean indicating if the vertex is connected to a ground."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)

def set_position(self, vertex: VT, q: FloatInt, r: FloatInt):
"""Set both the qubit index and row index of the vertex."""
self.set_qubit(vertex, q)
self.set_row(vertex, r)

def clear_vdata(self, vertex: VT) -> None:
"""Removes all vdata associated to a vertex"""
raise NotImplementedError("Not implemented on backend" + type(self).backend)
Expand All @@ -323,9 +249,13 @@ def set_vdata(self, vertex: VT, key: str, val: Any) -> None:

# These only need to be overridden if the backend will be used with hybrid classical/quantum
# methods.
def is_ground(self, vertex: VT) -> bool:
"""Returns a boolean indicating if the vertex is connected to a ground."""
return False

def grounds(self) -> Set[VT]:
"""Returns the set of vertices connected to a ground."""
raise NotImplementedError("Not implemented on backend" + type(self).backend)
return set(v for v in self.vertices() if self.is_ground(v))

def set_ground(self, vertex: VT, flag: bool=True) -> None:
"""Connect or disconnect the vertex to a ground."""
Expand All @@ -336,9 +266,31 @@ def is_hybrid(self) -> bool:
i.e. a graph with ground generators."""
return bool(self.grounds())

# Override and set to true if the backend supports parallel edges
def multigraph(self):
return False


# Backends may wish to override these methods to implement them more efficiently

# These methods return mappings from vertices to various pieces of data. If the backend
# stores these e.g. as Python dicts, just return the relevant dicts.
def phases(self) -> Mapping[VT, FractionLike]:
"""Returns a mapping of vertices to their phase values."""
return { v: self.phase(v) for v in self.vertices() }

def types(self) -> Mapping[VT, VertexType]:
"""Returns a mapping of vertices to their types."""
return { v: self.type(v) for v in self.vertices() }

def qubits(self) -> Mapping[VT,FloatInt]:
"""Returns a mapping of vertices to their qubit index."""
return { v: self.qubit(v) for v in self.vertices() }

def rows(self) -> Mapping[VT, FloatInt]:
"""Returns a mapping of vertices to their row index."""
return { v: self.row(v) for v in self.vertices() }

def depth(self) -> FloatInt:
"""Returns the value of the highest row number given to a vertex.
This is -1 when no rows have been set."""
Expand All @@ -347,7 +299,18 @@ def depth(self) -> FloatInt:
else:
return max(self.row(v) for v in self.vertices())

def multigraph(self):
def edge(self, s:VT, t:VT, et: EdgeType=EdgeType.SIMPLE) -> ET:
"""Returns the name of the first edge with the given source/target and type. Behaviour is undefined if the vertices are not connected."""
for e in self.incident_edges(s):
if t in self.edge_st(e) and et == self.edge_type(e):
return e
raise ValueError(f"No edge of type {et} between {s} and {t}")

def connected(self,v1: VT,v2: VT) -> bool:
"""Returns whether vertices v1 and v2 share an edge."""
for e in self.incident_edges(v1):
if v2 in self.edge_st(e):
return True
return False

def add_vertex(self,
Expand Down Expand Up @@ -377,6 +340,11 @@ def add_vertex(self,
self.phase_mult[self.max_phase_index] = 1
return v

def add_edges(self, edge_pairs: Iterable[Tuple[VT,VT]], edgetype:EdgeType=EdgeType.SIMPLE) -> None:
"""Adds a list of edges to the graph."""
for ep in edge_pairs:
self.add_edge(ep, edgetype)

def remove_vertex(self, vertex: VT) -> None:
"""Removes the given vertex from the graph."""
self.remove_vertices([vertex])
Expand All @@ -389,8 +357,57 @@ def add_to_phase(self, vertex: VT, phase: FractionLike) -> None:
"""Add the given phase to the phase value of the given vertex."""
self.set_phase(vertex,self.phase(vertex)+phase)

def num_inputs(self) -> int:
"""Gets the number of inputs of the graph."""
return len(self.inputs())

def num_outputs(self) -> int:
"""Gets the number of outputs of the graph."""
return len(self.outputs())

def set_position(self, vertex: VT, q: FloatInt, r: FloatInt):
"""Set both the qubit index and row index of the vertex."""
self.set_qubit(vertex, q)
self.set_row(vertex, r)

def neighbors(self, vertex: VT) -> Sequence[VT]:
"""Returns all neighboring vertices of the given vertex."""
vs: Set[VT] = set()
for e in self.incident_edges(vertex):
s,t = self.edge_st(e)
vs.add(s if t == vertex else t)
return list(vs)

def vertex_degree(self, vertex: VT) -> int:
"""Returns the degree of the given vertex."""
return len(self.incident_edges(vertex))

def edge_s(self, edge: ET) -> VT:
"""Returns the source of the given edge."""
return self.edge_st(edge)[0]

def edge_t(self, edge: ET) -> VT:
"""Returns the target of the given edge."""
return self.edge_st(edge)[1]


def vertex_set(self) -> Set[VT]:
"""Returns the vertices of the graph as a Python set.
Should be overloaded if the backend supplies a cheaper version than this."""
return set(self.vertices())

def edge_set(self) -> Set[ET]:
"""Returns the edges of the graph as a Python set.
Should be overloaded if the backend supplies a cheaper version than this. Note this ignores parallel edges."""
return set(self.edges())


# }}}

# def vindex(self) -> VT:
# """The index given to the next vertex added to the graph. It should always
# be equal to ``max(g.vertices()) + 1``."""
# raise NotImplementedError("Not implemented on backend " + type(self).backend)

def __str__(self) -> str:
return "Graph({} vertices, {} edges)".format(
Expand Down Expand Up @@ -550,8 +567,6 @@ def compose(self, other: BaseGraph[VT,ET]) -> None:
self.add_edge((no,vtab[ni]), edgetype=et)
self.set_outputs(tuple(vtab[v] for v in other.outputs()))



def tensor(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
"""Take the tensor product of two graphs. Places the second graph below the first one.
Can also be called using the operator ``graph1 @ graph2``"""
Expand Down Expand Up @@ -763,8 +778,6 @@ def from_tikz(cls, tikz: str, warn_overlap:bool= True, fuse_overlap:bool = True,
from ..tikz import tikz_to_graph
return tikz_to_graph(tikz,warn_overlap, fuse_overlap, ignore_nonzx, cls.backend)



def is_id(self) -> bool:
"""Returns whether the graph is just a set of identity wires,
i.e. a graph where all the vertices are either inputs or outputs,
Expand Down

0 comments on commit 59cccde

Please sign in to comment.