Integrate graph objects into hdbscan #539

Draft · wants to merge 8 commits into base: master

Changes from 4 commits
140 changes: 140 additions & 0 deletions examples/plot_hdbscan_graph.py
@@ -0,0 +1,140 @@
import numpy as np
import sklearn.metrics
from scipy import sparse
import igraph
import networkx as nx
import time
from hdbscan import HDBSCAN


def create_distance_matrix(graph):
    """
    Creates a distance matrix from the given graph using the igraph
    shortest-path algorithm.

    :param graph: An igraph graph object.
    :return: A scipy csr matrix of pairwise shortest-path distances.
    """
    # create variables
    path_weight, vertex_from_list, vertex_to_list, vertex_from = [], [], [], 0

    # create a distance matrix based on the graph, one source vertex at a time
    for vertex in graph.vs:
        list_edges_shortest_path = graph.get_shortest_paths(
            vertex, to=None, weights="weight", mode="out", output="epath"
        )
        vertex_to = 0

        for edge_list in list_edges_shortest_path:
            if edge_list:
                vertex_from_list.append(vertex_from)
                vertex_to_list.append(vertex_to)
                # the path length is the sum of the weights of its edges
                path_weight.append(sum(graph.es.select(edge_list)["weight"]))
            else:
                vertex_from_list.append(vertex_from)
                vertex_to_list.append(vertex_to)
                path_weight.append(0)

            vertex_to += 1
        vertex_from += 1

    # transform the lists into a csr matrix
    distance_matrix = sparse.csr_matrix((path_weight, (vertex_from_list, vertex_to_list)))

    return distance_matrix
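

# A possible shortcut (a sketch, not part of this PR): python-igraph can compute
# all pairwise shortest-path lengths directly, which would replace the manual
# loop above. Assumes igraph's Graph.shortest_paths(weights=...) API (0.9.x).
def create_distance_matrix_builtin(graph):
    """Sketch: the same distance matrix via igraph's built-in solver."""
    lengths = np.asarray(graph.shortest_paths(weights="weight"))
    # csr matrices treat 0 as "no entry"; drop unreachable (inf) pairs
    lengths[~np.isfinite(lengths)] = 0.0
    return sparse.csr_matrix(lengths)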


def hdbscan_graph():
    """
    Creates a weighted stochastic_block_model graph to compare the new graph
    metric of HDBSCAN against the precomputed metric run on a distance matrix
    created from the same graph.
    """
    # measure time
    start_build_graph = time.time()

    # set the parameters for the graph and its edges
    number_communities = np.random.randint(3, 20, 1)[0]
    edge_weight_in_community = 0.1
    edge_weight_out_community = 1

    # create the graph
    community_sizes = np.random.randint(low=30, high=70, size=number_communities)
    matrix_prob = np.random.rand(number_communities, number_communities)
    matrix_prob = (np.tril(matrix_prob) + np.tril(matrix_prob, -1).T) * 0.5
    np.fill_diagonal(matrix_prob, 0.7)
    sbm_graph = nx.stochastic_block_model(community_sizes, matrix_prob, seed=0)

    # convert to an igraph object
    graph = igraph.Graph(n=sbm_graph.number_of_nodes(), directed=False)
    graph.add_edges(sbm_graph.edges())

    # delete duplicate edges and loops, then drop isolated vertices
    graph.simplify()
    graph.vs.select(_degree=0).delete()

    # run community detection, which is used below to assign edge weights;
    # the graph metric requires a weighted graph
    community_detection = graph.community_multilevel()

    # add edge weights: cheap within a community, expensive between communities
    weight_list = []
    for edge in graph.es:
        vertex_1 = edge.source
        vertex_2 = edge.target
        edge_weight_added = False
        for subgraph in community_detection:
            if vertex_1 in subgraph and vertex_2 in subgraph:
                weight_list.append(edge_weight_in_community)
                edge_weight_added = True
                break
        if not edge_weight_added:
            weight_list.append(edge_weight_out_community)

    graph.es["weight"] = weight_list

    print("Graph created:", time.time() - start_build_graph)

    # run HDBSCAN on the graph distance matrix
    start_distance_matrix = time.time()

    # create a distance matrix from the graph
    distance_matrix = create_distance_matrix(graph)

    # run HDBSCAN on the created distance matrix
    clusterer = HDBSCAN(metric="precomputed").fit(distance_matrix)
    labels_distance_matrix = clusterer.labels_

    # measure time
    print("HDBSCAN distance matrix:", time.time() - start_distance_matrix)

    # plot the graph clustering using igraph
    graph.vs["label_distance_matrix"] = labels_distance_matrix
    vclustering = igraph.clustering.VertexClustering.FromAttribute(graph, "label_distance_matrix")
    igraph.plot(vclustering)

"""
Convert the iGraph graph into a csr sparse matrix, which the modified HDBSCAN function accepts and
transforms into a scipy csgraph.
"""
# run HDBSCAN using the graph metric
start_hdbscan_graph = time.time()

# create adjacency matrix from the graph, csr sparse matrix format
adjacency = graph.get_adjacency_sparse(attribute="weight")

clusterer = HDBSCAN(metric="graph").fit(adjacency)
labels_hdbscan_graph = clusterer.labels_

print("HDBSCAN graph:", time.time() - start_hdbscan_graph)

# plot clustering labels using iGraph
graph.vs["label_hdbscan_graph"] = labels_hdbscan_graph
vclustering = igraph.clustering.VertexClustering.FromAttribute(graph, "label_hdbscan_graph")
igraph.plot(vclustering)

# print the AMI and ARI for the labels
print("AMI:", sklearn.metrics.adjusted_mutual_info_score(labels_distance_matrix, labels_hdbscan_graph))
print("ARI:", sklearn.metrics.adjusted_rand_score(labels_distance_matrix, labels_hdbscan_graph))


"""
run the example function displaying the graph feature of HDBSCAN
"""
hdbscan_graph()
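
# Note: the igraph.plot calls above need a drawing backend; pycairo (listed in
# requirements.txt) provides one.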
102 changes: 70 additions & 32 deletions hdbscan/hdbscan_.py
@@ -96,6 +96,24 @@ def _hdbscan_generic(
        # matrix to indicate missing distance information.
        # TODO: Check if copying is necessary
        distance_matrix = X.copy()
    elif metric == "graph":
        # X should be the adjacency matrix of the graph in csr sparse format;
        # it is converted directly into a sparse minimum spanning tree
        adjacency_matrix = X

        # delegate to the sparse distance matrix function with metric "graph"
        return _hdbscan_sparse_distance_matrix(
            adjacency_matrix,
            min_samples,
            alpha,
            "graph",
            p,
            leaf_size,
            gen_min_span_tree,
            **kwargs
        )
    else:
        distance_matrix = pairwise_distances(X, metric=metric, **kwargs)

@@ -162,42 +180,62 @@ def _hdbscan_sparse_distance_matrix(
    **kwargs
):
    assert issparse(X)

    # if the metric is not graph, build the mutual reachability graph from
    # the sparse distance matrix
    if metric != "graph":
        # Check for connected components on X
        if csgraph.connected_components(X, directed=False, return_labels=False) > 1:
            raise ValueError(
                "Sparse distance matrix has multiple connected "
                "components!\nThat is, there exist groups of points "
                "that are completely disjoint -- there are no distance "
                "relations connecting them\n"
                "Run hdbscan on each component."
            )

        lil_matrix = X.tolil()

        # Compute the sparse mutual reachability graph
        # if max_dist > 0, max distance to use when the reachability is infinite
        max_dist = kwargs.get("max_dist", 0.0)
        mutual_reachability_ = sparse_mutual_reachability(
            lil_matrix, min_points=min_samples, max_dist=max_dist, alpha=alpha
        )

        # Check connected components on the mutual reachability graph.
        # If there is more than one component, then even though the distance
        # matrix X has a single component, some points have fewer than
        # `min_samples` neighbors.
        if (
            csgraph.connected_components(
                mutual_reachability_, directed=False, return_labels=False
            )
            > 1
        ):
            raise ValueError(
                (
                    "There exist points with fewer than %s neighbors. "
                    "Ensure your distance matrix has non-zero values for "
                    "at least `min_samples`=%s neighbors for each point "
                    "(i.e. a K-nn graph), or specify a `max_dist` to use "
                    "when distances are missing."
                )
                % (min_samples, min_samples)
            )

    # otherwise treat the csr adjacency matrix of the graph as the mutual
    # reachability graph and convert it into a minimum spanning tree below
    else:
        # check the components of the graph
        if csgraph.connected_components(X)[0] > 1:
            raise ValueError(
                "The passed graph has more than one component.\n"
                "Run hdbscan on each component."
            )
        # if there is one component, use the csr matrix of the graph directly
        else:
            mutual_reachability_ = X

    # Compute the minimum spanning tree for the sparse graph
    sparse_min_spanning_tree = csgraph.minimum_spanning_tree(mutual_reachability_)
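
For reference, a minimal usage sketch of the new code path (assuming a connected, weighted igraph graph such as the one built in the example script above):

    # adjacency holds the edge weights of the graph in csr sparse format
    adjacency = graph.get_adjacency_sparse(attribute="weight")
    # with metric="graph", the adjacency matrix is used directly as the
    # mutual reachability graph
    labels = HDBSCAN(metric="graph").fit(adjacency).labels_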
44 changes: 34 additions & 10 deletions hdbscan/tests/test_hdbscan.py
@@ -3,6 +3,8 @@
Shamelessly based on (i.e. ripped off from) the DBSCAN test code
"""
import numpy as np
import networkx as nx
import sklearn.metrics
from scipy.spatial import distance
from scipy import sparse
from scipy import stats
@@ -251,11 +253,13 @@ def test_hdbscan_generic():
    n_clusters_2 = len(set(labels)) - int(-1 in labels)
    assert n_clusters_2 == n_clusters


def test_hdbscan_dbscan_clustering():
    clusterer = HDBSCAN().fit(X)
    labels = clusterer.dbscan_clustering(0.3)
    n_clusters_1 = len(set(labels)) - int(-1 in labels)
    assert (n_clusters == n_clusters_1)


def test_hdbscan_high_dimensional():
    H, y = make_blobs(n_samples=50, random_state=0, n_features=64)
@@ -267,8 +271,8 @@ def test_hdbscan_high_dimensional():

    labels = (
        HDBSCAN(algorithm="best", metric="seuclidean", V=np.ones(H.shape[1]))
        .fit(H)
        .labels_
    )
    n_clusters_2 = len(set(labels)) - int(-1 in labels)
    assert n_clusters_2 == n_clusters
@@ -330,7 +334,6 @@ def test_hdbscan_input_lists():


def test_hdbscan_boruvka_kdtree_matches():
    data = generate_noisy_data()

    labels_prims, p, persist, ctree, ltree, mtree = hdbscan(data, algorithm="generic")
@@ -351,7 +354,6 @@ def test_hdbscan_boruvka_kdtree_matches():


def test_hdbscan_boruvka_balltree_matches():
    data = generate_noisy_data()

    labels_prims, p, persist, ctree, ltree, mtree = hdbscan(data, algorithm="generic")
@@ -414,7 +416,6 @@ def test_min_span_tree_plot():


def test_tree_numpy_output_formats():
    clusterer = HDBSCAN(gen_min_span_tree=True).fit(X)

    clusterer.single_linkage_tree_.to_numpy()
@@ -423,15 +424,13 @@ def test_tree_numpy_output_formats():


def test_tree_pandas_output_formats():
    clusterer = HDBSCAN(gen_min_span_tree=True).fit(X)
    if_pandas(clusterer.condensed_tree_.to_pandas)()
    if_pandas(clusterer.single_linkage_tree_.to_pandas)()
    if_pandas(clusterer.minimum_spanning_tree_.to_pandas)()


def test_tree_networkx_output_formats():
    clusterer = HDBSCAN(gen_min_span_tree=True).fit(X)
    if_networkx(clusterer.condensed_tree_.to_networkx)()
    if_networkx(clusterer.single_linkage_tree_.to_networkx)()
@@ -576,7 +575,6 @@ def test_hdbscan_badargs():


def test_hdbscan_sparse():
    sparse_X = sparse.csr_matrix(X)

    labels = HDBSCAN().fit(sparse_X).labels_
@@ -585,7 +583,6 @@ def test_hdbscan_sparse():


def test_hdbscan_caching():
    cachedir = mkdtemp()
    labels1 = HDBSCAN(memory=cachedir, min_samples=5).fit(X).labels_
    labels2 = HDBSCAN(memory=cachedir, min_samples=5, min_cluster_size=6).fit(X).labels_
@@ -646,6 +643,33 @@ def test_hdbscan_is_sklearn_estimator():
    check_estimator(HDBSCAN)


def test_hdbscan_graph():
    # create a distance matrix, see test_hdbscan_distance_matrix
    D = distance.squareform(distance.pdist(X))
    D /= np.max(D)

    threshold = stats.scoreatpercentile(D.flatten(), 50)

    D[D >= threshold] = 0.0
    D = sparse.csr_matrix(D)
    D.eliminate_zeros()

    # create cluster labels using the precomputed metric
    clusterer = HDBSCAN(metric="precomputed").fit(D)
    labels_distance_matrix = clusterer.labels_

    # create a graph from the distance matrix and convert it to a csr
    # adjacency matrix
    graph = nx.from_numpy_matrix(D.toarray())
    adjacency_matrix = nx.adjacency_matrix(graph)

    # create cluster labels using the graph metric
    clusterer = HDBSCAN(metric="graph").fit(adjacency_matrix)
    labels_hdbscan_graph = clusterer.labels_

    assert sklearn.metrics.accuracy_score(labels_distance_matrix, labels_hdbscan_graph) == 1
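    # note: accuracy_score assumes both clusterings use the same label
    # numbering; a permutation-invariant check would be, e.g.:
    # assert sklearn.metrics.adjusted_rand_score(
    #     labels_distance_matrix, labels_hdbscan_graph) == 1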



# Probably not applicable now #
# def test_dbscan_sparse():
# def test_dbscan_balltree():
8 changes: 8 additions & 0 deletions requirements.txt
@@ -3,3 +3,11 @@ numpy>=1.20
scipy>= 1.0
scikit-learn>=0.20
joblib>=1.0

pytest~=7.1.1
hdbscan~=0.8.28
networkx~=2.8
matplotlib~=3.5.1
igraph~=0.9.9
pycairo~=1.21.0
setuptools~=61.2.0