diff --git a/src/boruvkas_algorithm/boruvka.py b/src/boruvkas_algorithm/boruvka.py index febfa57..da54705 100644 --- a/src/boruvkas_algorithm/boruvka.py +++ b/src/boruvkas_algorithm/boruvka.py @@ -1,6 +1,8 @@ """ Implement Boruvka's algorithm for finding the minimum spanning tree of a graph. """ +import matplotlib.pyplot as plt +import networkx as nx from typing import Dict, List, Optional, Tuple @@ -217,6 +219,38 @@ def perform_iteration( return mst_weight, num_components + def draw_mst(self, mst_edges: List[Tuple[int, int, int]]) -> None: + """ + Draw the graph with the minimum spanning tree highlighted using + networkx. + + Args: + mst_edges: A list of edges in the minimum spanning tree. + """ + G = nx.Graph() + # Add nodes to the graph. + G.add_nodes_from(self.vertices) + # Add all edges to the graph with weights. + for edge in self.edges: + vertex1, vertex2, weight = edge + G.add_edge(vertex1, vertex2, weight=weight) + pos = nx.spring_layout(G) + # Draw the graph edges and highlight the edges in the MST in red. + nx.draw_networkx_edges( + G, pos, edgelist=self.edges, edge_color="gray", alpha=0.5 + ) + nx.draw_networkx_edges(G, pos, edgelist=mst_edges, edge_color="red", width=2) + # Draw the graph nodes and labels. + nx.draw_networkx_nodes(G, pos, node_size=700, node_color="lightblue") + nx.draw_networkx_labels(G, pos) + nx.draw_networkx_edge_labels( + G, pos, edge_labels={(u, v): d["weight"] for u, v, d in G.edges(data=True)} + ) + + plt.title("Graph with Minimum Spanning Tree Highlighted") + plt.axis("off") + plt.show() + def run_boruvkas_algorithm(self): """ Find the minimum spanning tree (MST) of the graph using Boruvka's @@ -254,12 +288,13 @@ def run_boruvkas_algorithm(self): mst_weight, ) - # Summarise the MST found. + # Summarise the MST found and draw it. print("\nMST found with Boruvka's algorithm.") print("MST edges (vertex_1, vertex_2, weight):") for edge in sorted(mst_edges): print(f" {edge}") print(f"MST weight: {mst_weight}") + self.draw_mst(mst_edges) return mst_weight, mst_edges