-
Notifications
You must be signed in to change notification settings - Fork 1
/
clustering.py
67 lines (50 loc) · 2.01 KB
/
clustering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Python implementation of clustering based on union find structure."""
from disjoint_set import DisjointSet
def k_clusters(items, k):
    """Group the nodes referenced by ``items`` into ``k`` clusters.

    Every node starts in its own cluster.  Edges are then visited in the
    order given, and each edge whose endpoints are still in different
    clusters merges those two clusters, until only ``k`` clusters remain.
    When the edges are sorted by ascending distance, the closest clusters
    are therefore merged first.

    :param items: sorted collection of (node1, node2, distance) objects.
        Only the first two fields are read here; the function relies on the
        caller having sorted the edges.  Any iterable is accepted, including
        a one-shot iterator.
    :param k: int, number of clusters we want to have returned.
    :return: the result of ``DisjointSet.get_clusters()``, described as
        {cluster1: [node1, node2], cluster2: [node3] ...}.  If there are
        fewer than ``k`` nodes, or the edges run out before ``k`` clusters
        are reached, the clustering obtained so far is returned (never
        ``None``).
    """
    union_find = DisjointSet()

    # Materialise the endpoint pairs once: the edges are walked twice below,
    # which would silently fail for a generator.
    edges = [(item[0], item[1]) for item in items]

    # Initial setting of all points, where each node belongs to its own
    # cluster.
    for node1, node2 in edges:
        union_find.make_set(node1)
        union_find.make_set(node2)

    # After every union operation the clusters counter is decreased by 1.
    # NOTE(review): assumes len(union_find) is the current number of
    # disjoint sets -- confirm against DisjointSet.__len__.
    counter = len(union_find)
    for node1, node2 in edges:
        # "<=" rather than "==" so that a k larger than the number of nodes
        # stops immediately instead of merging everything.
        if counter <= k:
            break
        if union_find.is_connected(node1, node2):
            continue
        union_find.union(node1, node2)
        counter -= 1

    # Always hand back the clustering, including when the target count is
    # reached on the very last edge or the edges are exhausted early.
    return union_find.get_clusters()
if __name__ == "__main__":
    # Demo: cluster 40 pseudo-random 2-D points and plot each cluster in its
    # own colour.
    import numpy as np
    import matplotlib.pyplot as plt
    from itertools import combinations

    np.random.seed(5)  # fixed seed so the demo is reproducible

    # Generate some points with integer coordinates in [1, 100].
    # randint's upper bound is exclusive, hence 101; it replaces the
    # deprecated np.random.random_integers(1, 100, 40).
    xs = np.random.randint(1, 101, 40)
    ys = np.random.randint(1, 101, 40)
    points = dict(enumerate(zip(xs, ys)))

    # Create edges with distances. Input for k_clusters function.
    edges = []
    for first, second in combinations(points, 2):
        x1, y1 = points[first]
        x2, y2 = points[second]
        # Euclidean distance between the two points.
        dist = ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
        edges.append((first, second, dist))
    # k_clusters expects the edges ordered by ascending distance.
    edges.sort(key=lambda edge: edge[2])

    # One colour per cluster; the number of clusters requested follows the
    # palette size so the two cannot drift apart.
    colors = 'rgbcmyk'

    # Clustering
    clusters = k_clusters(edges, len(colors))

    # Plot points. Clusters distinguished by colors.
    for idx, cluster in enumerate(clusters):
        coords = [points[node] for node in clusters[cluster]]
        x_coords, y_coords = zip(*coords)
        # Wrap the index so a differing cluster count cannot raise IndexError.
        plt.scatter(x_coords, y_coords, c=colors[idx % len(colors)])
    plt.show()