Skip to content

Commit

Permalink
Draw graph with networkx
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacCheng9 committed Feb 4, 2024
1 parent 7d80ee0 commit 24b30b2
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion src/boruvkas_algorithm/boruvka.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Implement Boruvka's algorithm for finding the minimum spanning tree of a graph.
"""
import matplotlib.pyplot as plt
import networkx as nx

from typing import Dict, List, Optional, Tuple

Expand Down Expand Up @@ -217,6 +219,38 @@ def perform_iteration(

return mst_weight, num_components

def draw_mst(self, mst_edges: List[Tuple[int, int, int]]) -> None:
"""
Draw the graph with the minimum spanning tree highlighted using
networkx.
Args:
mst_edges: A list of edges in the minimum spanning tree.
"""
G = nx.Graph()
# Add nodes to the graph.
G.add_nodes_from(self.vertices)
# Add all edges to the graph with weights.
for edge in self.edges:
vertex1, vertex2, weight = edge
G.add_edge(vertex1, vertex2, weight=weight)
pos = nx.spring_layout(G)
# Draw the graph edges and highlight the edges in the MST in red.
nx.draw_networkx_edges(
G, pos, edgelist=self.edges, edge_color="gray", alpha=0.5
)
nx.draw_networkx_edges(G, pos, edgelist=mst_edges, edge_color="red", width=2)
# Draw the graph nodes and labels.
nx.draw_networkx_nodes(G, pos, node_size=700, node_color="lightblue")
nx.draw_networkx_labels(G, pos)
nx.draw_networkx_edge_labels(
G, pos, edge_labels={(u, v): d["weight"] for u, v, d in G.edges(data=True)}
)

plt.title("Graph with Minimum Spanning Tree Highlighted")
plt.axis("off")
plt.show()

def run_boruvkas_algorithm(self):
"""
Find the minimum spanning tree (MST) of the graph using Boruvka's
Expand Down Expand Up @@ -254,12 +288,13 @@ def run_boruvkas_algorithm(self):
mst_weight,
)

# Summarise the MST found.
# Summarise the MST found and draw it.
print("\nMST found with Boruvka's algorithm.")
print("MST edges (vertex_1, vertex_2, weight):")
for edge in sorted(mst_edges):
print(f" {edge}")
print(f"MST weight: {mst_weight}")
self.draw_mst(mst_edges)

return mst_weight, mst_edges

Expand Down

0 comments on commit 24b30b2

Please sign in to comment.