-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_condensed_tree.py
93 lines (72 loc) · 2.65 KB
/
plot_condensed_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import sys
import numpy
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import networkx as nx
from networkx.drawing.nx_agraph import write_dot, graphviz_layout
def read_dataset(path):
    """Load the condensed-tree table stored as CSV at *path*.

    Args:
        path: anything ``pandas.read_csv`` accepts (file path or file-like).

    Returns:
        The parsed ``pandas.DataFrame``, using pandas' default dtype inference.
    """
    frame = pd.read_csv(path)
    return frame
def _is_null_child(child_id):
    """Return True if *child_id* is the ``0`` sentinel marking a row without children.

    ``pd.read_csv`` parses an all-numeric column as int (or float when values
    are missing), so the sentinel can arrive as ``0``, ``0.0`` or ``"0"``;
    a plain ``!= "0"`` string comparison misses the numeric forms.
    """
    if isinstance(child_id, str):
        return child_id.strip() == "0"
    try:
        return float(child_id) == 0.0
    except (TypeError, ValueError):
        return False


def create_graph(df, labels=False):
    """Build the condensed tree described by *df*, draw it and show the plot.

    Each row of *df* is one node. Columns read: ``id``, ``child1_id``,
    ``child2_id``, ``weight``, ``selected`` and, when *labels* is true,
    ``size``. A ``child1_id`` equal to 0 marks a row without children;
    otherwise edges ``id -> child1_id`` and ``id -> child2_id`` are added.
    Rows whose ``selected`` equals 1 are outlined in red, all others in black.
    Node area and colour come from ``weight * 20 / number_of_rows``, with one
    colour scale shared by both groups. The layout uses Graphviz ``dot``
    through ``networkx.drawing.nx_agraph``.

    Args:
        df: table as returned by ``read_dataset``.
        labels: if true, annotate each node with ``"<row index + 1>:<size>"``.

    Returns:
        The ``networkx.DiGraph`` that was built (and drawn, unless empty).
    """
    DG = nx.DiGraph()
    weight_by_node = {}
    selected_nodes = []
    unselected_nodes = []
    selected_weights = []
    unselected_weights = []
    # Separate from the *labels* flag: reusing that name for this dict made
    # the flag always falsy, so labels were never produced.
    node_labels = {}

    for index, row in df.iterrows():
        node = row["id"]
        # Register every row so childless rows (e.g. a lone root) still get a
        # layout position and can be drawn.
        DG.add_node(node)
        if not _is_null_child(row["child1_id"]):
            DG.add_edge(node, row["child1_id"])
            DG.add_edge(node, row["child2_id"])
        weight_by_node[node] = row["weight"]
        if labels:
            node_labels[node] = "{}:{}".format(index + 1, row["size"])
        if row["selected"] == 1:
            selected_nodes.append(node)
            selected_weights.append(row["weight"])
        else:
            unselected_nodes.append(node)
            unselected_weights.append(row["weight"])

    if DG.number_of_nodes() == 0:
        # nx.is_tree raises on a graph with no nodes; nothing to draw anyway.
        print("Dataset is empty; nothing to plot.")
        return DG
    if not nx.is_tree(DG):
        print("Generated tree is not actually a tree. Something went wrong!")

    # Rescale weights into marker areas with one factor for all rows.
    scale = 20 / (len(selected_weights) + len(unselected_weights))
    selected_sizes = [scale * w for w in selected_weights]
    unselected_sizes = [scale * w for w in unselected_weights]
    # Share one colour normalisation so equal weights get equal colours in
    # both groups (each draw call would otherwise normalise on its own).
    all_sizes = selected_sizes + unselected_sizes
    vmin, vmax = min(all_sizes), max(all_sizes)
    # draw_networkx_edges indexes node_size by DG.nodes order, not row order.
    edge_node_sizes = [scale * weight_by_node.get(n, 0) for n in DG.nodes]

    plt.subplot(121)
    plt.title("Condensed_tree")
    pos = graphviz_layout(DG, prog="dot")
    nx.draw_networkx_nodes(
        DG,
        pos,
        nodelist=unselected_nodes,
        edgecolors="#000000",
        node_size=unselected_sizes,
        node_color=unselected_sizes,
        cmap=plt.cm.viridis,
        vmin=vmin,
        vmax=vmax,
    )
    nx.draw_networkx_nodes(
        DG,
        pos,
        nodelist=selected_nodes,
        edgecolors="#ff0000",
        linewidths=2,
        node_size=selected_sizes,
        node_color=selected_sizes,
        cmap=plt.cm.viridis,
        vmin=vmin,
        vmax=vmax,
    )
    nx.draw_networkx_edges(
        DG,
        pos,
        node_size=edge_node_sizes,
        arrowstyle="->",
        arrowsize=3,
        width=1,
    )
    if labels:
        x_off = 100  # shift labels right along the x axis, in layout units
        pos_right = {k: (x + x_off, y) for k, (x, y) in pos.items()}
        nx.draw_networkx_labels(DG, pos_right, node_labels, font_size=10)
    plt.axis("off")
    plt.show()
    return DG
def main():
    """Command-line entry point: plot the condensed tree in the CSV named by argv[1].

    With any argument count other than exactly one, prints a usage line to
    stderr and returns without plotting.
    """
    import os

    if len(sys.argv) != 2:
        # Report the name of the script actually invoked.
        prog = os.path.basename(sys.argv[0]) if sys.argv else ""
        prog = prog or "plot_condensed_tree.py"
        print("Usage: {} <tree_path>".format(prog), file=sys.stderr)
        return
    df = read_dataset(sys.argv[1])
    create_graph(df)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()