Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic Graph Drawing for draw_d3 #277

Merged
merged 9 commits into from
Nov 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 61 additions & 16 deletions pyzx/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def draw(g: Union[BaseGraph[VT,ET], Circuit], labels: bool=False, **kwargs) -> A
# allow global setting to labels=False
# TODO: probably better to make labels Optional[bool]
labels = labels or settings.show_labels

if get_mode() == "shell":
return draw_matplotlib(g, labels, **kwargs)
elif get_mode() == "browser":
Expand Down Expand Up @@ -288,15 +287,51 @@ def draw_matplotlib(
# library_code += '</script>'
# display(HTML(library_code))

def auto_layout_vertex_locs(g:BaseGraph[VT, ET]): #Force-based graph drawing algorithm given by Eades(1984):
c1 = 2 #Sample parameters that work decently well
c2 = 1
c3 = 1
c4 = .1
v_locs:Dict[VT, Tuple[float, float]] = dict()
for v in g.vertices():
v_locs[v]=(random.random()*math.sqrt(g.num_vertices()), random.random()*math.sqrt(g.num_vertices()))
for i in range(100): #100 iterations of force-based drawing
forces:Dict[VT, Tuple[float, float]] = dict()
for v in g.vertices():
forces[v] = (0, 0)
for v1 in g.vertices():
if(v!=v1):
diff = (v_locs[v][0]-v_locs[v1][0], v_locs[v][1]-v_locs[v1][1])
d = math.sqrt(diff[0]*diff[0]+diff[1]*diff[1])
if g.connected(v1, v): #edge between vertices: apply rule c1*log(d/c2)
force_mag = -c1*math.log(d/c2) #negative force attracts
elif v != v1: #nonadjacent vertices: apply rule -c3/d^2
force_mag = c3/(d*d) #positive force repels
else: #free body in question, applies no force on itself
raise ValueError("Vertices ended up at same point")
v_force = (diff[0]*force_mag*c4/d, diff[1]*force_mag*c4/d)
forces[v] = (forces[v][0]+v_force[0], forces[v][1]+v_force[1])
for v in g.vertices(): #leave y value constant if input or output
v_locs[v]=(v_locs[v][0]+forces[v][0], v_locs[v][1]+forces[v][1])
max_x = max(v[0] for v in v_locs.values())
min_x = min(v[0] for v in v_locs.values())
max_y = max(v[1] for v in v_locs.values())
min_y = min(v[1] for v in v_locs.values())
v_locs = {k:(v[0]-min_x, v[1]-min_y) for k, v in v_locs.items()} #translate to origin
return v_locs, max_x-min_x, max_y-min_y


def draw_d3(
g: Union[BaseGraph[VT,ET], Circuit],
labels:bool=False,
scale:Optional[FloatInt]=None,
auto_hbox:Optional[bool]=None,
show_scalar:bool=False,
vdata: List[str]=[]
vdata: List[str]=[],
auto_layout = False
) -> Any:

"""If auto_layout is checked, will automatically space vertices of graph
with no regard to qubit/row."""
if get_mode() not in ("notebook", "browser"):
raise Exception("This method only works when loaded in a webpage or Jupyter notebook")

Expand All @@ -310,25 +345,35 @@ def draw_d3(
# use an 8-digit random alphanum instead.
graph_id = ''.join(random_graphid.choice(string.ascii_letters + string.digits) for _ in range(8))

minrow = min([g.row(v) for v in g.vertices()], default=0)
maxrow = max([g.row(v) for v in g.vertices()], default=0)
minqub = min([g.qubit(v) for v in g.vertices()], default=0)
maxqub = max([g.qubit(v) for v in g.vertices()], default=0)
if(auto_layout):
v_dict, w, h = auto_layout_vertex_locs(g)
if scale is None:
scale = 800 / w
if scale > 50: scale = 50
if scale < 20: scale = 20

w = (w+2) * scale
h = (h+3) * scale
else:
minrow = min([g.row(v) for v in g.vertices()], default=0)
maxrow = max([g.row(v) for v in g.vertices()], default=0)
minqub = min([g.qubit(v) for v in g.vertices()], default=0)
maxqub = max([g.qubit(v) for v in g.vertices()], default=0)

if scale is None:
scale = 800 / (maxrow-minrow + 2)
if scale > 50: scale = 50
if scale < 20: scale = 20

if scale is None:
scale = 800 / (maxrow-minrow + 2)
if scale > 50: scale = 50
if scale < 20: scale = 20
w = (maxrow-minrow + 2) * scale
h = (maxqub-minqub + 3) * scale

node_size = 0.2 * scale
if node_size < 2: node_size = 2

w = (maxrow-minrow + 2) * scale
h = (maxqub-minqub + 3) * scale

nodes = [{'name': str(v),
'x': (g.row(v)-minrow + 1) * scale,
'y': (g.qubit(v)-minqub + 2) * scale,
'x': (v_dict[v][0]+1)*scale if auto_layout else (g.row(v)-minrow + 1) * scale,
'y': (v_dict[v][1]+2)*scale if auto_layout else (g.qubit(v)-minqub + 2) * scale,
't': g.type(v),
'phase': phase_to_s(g.phase(v), g.type(v)) if g.type(v) != VertexType.Z_BOX else str(get_z_box_label(g, v)),
'ground': g.is_ground(v),
Expand Down
Loading