From 3586fb33d105340df7f243dd53ec42eadba892e6 Mon Sep 17 00:00:00 2001 From: markopy <48253511+markopy@users.noreply.github.com> Date: Sat, 8 Feb 2020 22:05:58 +0000 Subject: [PATCH 1/4] Faster version of sparse_mutual_reachability --- hdbscan/_hdbscan_reachability.pyx | 120 ++++++++++++++++++++---------- hdbscan/hdbscan_.py | 8 +- hdbscan/tests/test_hdbscan.py | 6 ++ 3 files changed, 91 insertions(+), 43 deletions(-) diff --git a/hdbscan/_hdbscan_reachability.pyx b/hdbscan/_hdbscan_reachability.pyx index df0f7b8b..9c7ae8cc 100644 --- a/hdbscan/_hdbscan_reachability.pyx +++ b/hdbscan/_hdbscan_reachability.pyx @@ -1,15 +1,16 @@ # cython: boundscheck=False +# cython: wraparound=False # cython: nonecheck=False # cython: initializedcheck=False -# mutual reachability distance compiutations +# mutual reachability distance computations # Authors: Leland McInnes # License: 3-clause BSD import numpy as np cimport numpy as np +from numpy.math cimport INFINITY, isfinite from scipy.spatial.distance import pdist, squareform -from scipy.sparse import lil_matrix as sparse_matrix from sklearn.neighbors import KDTree, BallTree import gc @@ -59,44 +60,87 @@ def mutual_reachability(distance_matrix, min_points=5, alpha=1.0): return result -cpdef sparse_mutual_reachability(object lil_matrix, np.intp_t min_points=5, - float alpha=1.0, float max_dist=0.): +cpdef sparse_mutual_reachability(object distance_matrix, np.intp_t min_points=5, + float alpha=1.0, float max_dist=0.0): + """ Compute mutual reachability for distance matrix. For best performance + pass distance_matrix in CSR form which will modify it in place and return + it without making a copy """ - cdef np.intp_t i - cdef np.intp_t j - cdef np.intp_t n - cdef np.double_t mr_dist - cdef list sorted_row_data - cdef np.ndarray[dtype=np.double_t, ndim=1] core_distance - cdef np.ndarray[dtype=np.int32_t, ndim=1] nz_row_data - cdef np.ndarray[dtype=np.int32_t, ndim=1] nz_col_data - - result = sparse_matrix(lil_matrix.shape) - core_distance = np.empty(lil_matrix.shape[0], dtype=np.double) - - for i in range(lil_matrix.shape[0]): - sorted_row_data = sorted(lil_matrix.data[i]) - if min_points < len(sorted_row_data): - core_distance[i] = sorted_row_data[min_points] + # tocsr() is a fast noop if distance_matrix is already CSR + D = distance_matrix.tocsr() + # Convert to smallest supported data type if necessary + if D.dtype not in (np.float32, np.float64): + if D.dtype <= np.dtype(np.float32): + D = D.astype(np.float32) else: - core_distance[i] = np.infty - - if alpha != 1.0: - lil_matrix = lil_matrix / alpha - - nz_row_data, nz_col_data = lil_matrix.nonzero() - - for n in range(nz_row_data.shape[0]): - i = nz_row_data[n] - j = nz_col_data[n] - - mr_dist = max(core_distance[i], core_distance[j], lil_matrix[i, j]) - if np.isfinite(mr_dist): - result[i, j] = mr_dist - elif max_dist > 0: - result[i, j] = max_dist - - return result.tocsr() + D = D.astype(np.float64) + + # Call typed function which modifies D in place + t = (D.data.dtype, D.indices.dtype, D.indptr.dtype) + if t == (np.float32, np.int32, np.int32): + sparse_mr_fast[np.float32_t, np.int32_t](D.data, D.indices, D.indptr, + min_points, alpha, max_dist) + elif t == (np.float32, np.int64, np.int64): + sparse_mr_fast[np.float32_t, np.int64_t](D.data, D.indices, D.indptr, + min_points, alpha, max_dist) + elif t == (np.float64, np.int32, np.int32): + sparse_mr_fast[np.float64_t, np.int32_t](D.data, D.indices, D.indptr, + min_points, alpha, max_dist) + elif t == (np.float64, np.int64, np.int64): + sparse_mr_fast[np.float64_t, np.int64_t](D.data, D.indices, D.indptr, + min_points, alpha, max_dist) + else: + raise Exception("Unsupported CSR format {}".format(t)) + + return D + + +ctypedef fused mr_indx_t: + np.int32_t + np.int64_t + +ctypedef fused mr_data_t: + np.float32_t + np.float64_t + +cdef sparse_mr_fast(np.ndarray[dtype=mr_data_t, ndim=1] data, + np.ndarray[dtype=mr_indx_t, ndim=1] indices, + np.ndarray[dtype=mr_indx_t, ndim=1] indptr, + mr_indx_t min_points, + mr_data_t alpha, + mr_data_t max_dist): + cdef mr_indx_t row, col, colptr + cdef mr_data_t mr_dist + cdef np.ndarray[dtype=mr_data_t, ndim=1] row_data + cdef np.ndarray[dtype=mr_data_t, ndim=1] core_distance + + core_distance = np.empty(data.shape[0], dtype=data.dtype) + + for row in range(indptr.shape[0]-1): + row_data = data[indptr[row]:indptr[row+1]].copy() + if min_points < row_data.shape[0]: + # sort is faster for small arrays because of lower startup cost but + # partition has worst case O(n) runtime for larger ones. + # https://stackoverflow.com/questions/43588711/numpys-partition-slower-than-sort-for-small-arrays + if row_data.shape[0] > 200: + row_data.partition(min_points) + else: + row_data.sort() + core_distance[row] = row_data[min_points] + else: + core_distance[row] = INFINITY + + if alpha != 1.0: + data /= alpha + + for row in range(indptr.shape[0]-1): + for colptr in range(indptr[row],indptr[row+1]): + col = indices[colptr] + mr_dist = max(core_distance[row], core_distance[col], data[colptr]) + if isfinite(mr_dist): + data[colptr] = mr_dist + elif max_dist > 0: + data[colptr] = max_dist def kdtree_mutual_reachability(X, distance_matrix, metric, p=2, min_points=5, diff --git a/hdbscan/hdbscan_.py b/hdbscan/hdbscan_.py index 31b7f84a..61e05279 100644 --- a/hdbscan/hdbscan_.py +++ b/hdbscan/hdbscan_.py @@ -78,13 +78,12 @@ def _hdbscan_generic(X, min_samples=5, alpha=1.0, metric='minkowski', p=2, # sklearn.metrics.pairwise_distances handle it, # enables the usage of numpy.inf in the distance # matrix to indicate missing distance information. - # TODO: Check if copying is necessary + # Need copy because distance_matrix may be modified if sparse distance_matrix = X.copy() else: distance_matrix = pairwise_distances(X, metric=metric, **kwargs) if issparse(distance_matrix): - # raise TypeError('Sparse distance matrices not yet supported') return _hdbscan_sparse_distance_matrix(distance_matrix, min_samples, alpha, metric, p, leaf_size, gen_min_span_tree, @@ -141,12 +140,11 @@ def _hdbscan_sparse_distance_matrix(X, min_samples=5, alpha=1.0, 'relations connecting them\n' 'Run hdbscan on each component.') - lil_matrix = X.tolil() - # Compute sparse mutual reachability graph # if max_dist > 0, max distance to use when the reachability is infinite max_dist = kwargs.get("max_dist", 0.) - mutual_reachability_ = sparse_mutual_reachability(lil_matrix, + # sparse_mutual_reachability() may modify X in place and return it + mutual_reachability_ = sparse_mutual_reachability(X, min_points=min_samples, max_dist=max_dist, alpha=alpha) diff --git a/hdbscan/tests/test_hdbscan.py b/hdbscan/tests/test_hdbscan.py index c2d5201d..c9fc1f73 100644 --- a/hdbscan/tests/test_hdbscan.py +++ b/hdbscan/tests/test_hdbscan.py @@ -154,6 +154,12 @@ def test_hdbscan_sparse_distance_matrix(): n_clusters_2 = len(set(labels)) - int(-1 in labels) assert_equal(n_clusters_2, n_clusters) + # Verify single and double precision return same results + assert_equal(D.dtype, np.double) + labels_double = hdbscan(D, metric='precomputed')[0] + labels_single = hdbscan(D.astype(np.single), metric='precomputed')[0] + assert_array_equal(labels_single, labels_double) + def test_hdbscan_feature_vector(): labels, p, persist, ctree, ltree, mtree = hdbscan(X) From e9f8dfb28d966beab8c1ae1097d6947f78463908 Mon Sep 17 00:00:00 2001 From: markopy <48253511+markopy@users.noreply.github.com> Date: Mon, 10 Feb 2020 00:46:12 +0000 Subject: [PATCH 2/4] More robust parameter checking for sparse reachability --- hdbscan/_hdbscan_reachability.pyx | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/hdbscan/_hdbscan_reachability.pyx b/hdbscan/_hdbscan_reachability.pyx index 9c7ae8cc..621bca64 100644 --- a/hdbscan/_hdbscan_reachability.pyx +++ b/hdbscan/_hdbscan_reachability.pyx @@ -68,6 +68,8 @@ cpdef sparse_mutual_reachability(object distance_matrix, np.intp_t min_points=5, # tocsr() is a fast noop if distance_matrix is already CSR D = distance_matrix.tocsr() + if D.shape != (D.shape[0], D.shape[0]): + raise Exception("Distance matrix must be 2D square shaped instead of {}".format(D.shape)) # Convert to smallest supported data type if necessary if D.dtype not in (np.float32, np.float64): if D.dtype <= np.dtype(np.float32): @@ -103,18 +105,18 @@ ctypedef fused mr_data_t: np.float32_t np.float64_t -cdef sparse_mr_fast(np.ndarray[dtype=mr_data_t, ndim=1] data, - np.ndarray[dtype=mr_indx_t, ndim=1] indices, - np.ndarray[dtype=mr_indx_t, ndim=1] indptr, +cdef sparse_mr_fast(np.ndarray[dtype=mr_data_t, ndim=1, mode='c'] data, + np.ndarray[dtype=mr_indx_t, ndim=1, mode='c'] indices, + np.ndarray[dtype=mr_indx_t, ndim=1, mode='c'] indptr, mr_indx_t min_points, mr_data_t alpha, mr_data_t max_dist): cdef mr_indx_t row, col, colptr cdef mr_data_t mr_dist - cdef np.ndarray[dtype=mr_data_t, ndim=1] row_data - cdef np.ndarray[dtype=mr_data_t, ndim=1] core_distance + cdef np.ndarray[dtype=mr_data_t, ndim=1, mode='c'] row_data + cdef np.ndarray[dtype=mr_data_t, ndim=1, mode='c'] core_distance - core_distance = np.empty(data.shape[0], dtype=data.dtype) + core_distance = np.empty(indptr.shape[0]-1, dtype=data.dtype) for row in range(indptr.shape[0]-1): row_data = data[indptr[row]:indptr[row+1]].copy() From b7bba04796f9c808d58101ab17f31f1114d2dd9e Mon Sep 17 00:00:00 2001 From: markopy <48253511+markopy@users.noreply.github.com> Date: Tue, 11 Feb 2020 16:35:20 +0000 Subject: [PATCH 3/4] Allow minimum_spanning_tree() to prune reachability matrix in place --- hdbscan/hdbscan_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hdbscan/hdbscan_.py b/hdbscan/hdbscan_.py index 61e05279..d31adf67 100644 --- a/hdbscan/hdbscan_.py +++ b/hdbscan/hdbscan_.py @@ -161,7 +161,7 @@ def _hdbscan_sparse_distance_matrix(X, min_samples=5, alpha=1.0, # Compute the minimum spanning tree for the sparse graph sparse_min_spanning_tree = csgraph.minimum_spanning_tree( - mutual_reachability_) + mutual_reachability_, overwrite=True) # Convert the graph to scipy cluster array format nonzeros = sparse_min_spanning_tree.nonzero() From edbc4c7ac3bcef437490842d695ba8e5ec0b6466 Mon Sep 17 00:00:00 2001 From: markopy <48253511+markopy@users.noreply.github.com> Date: Tue, 11 Feb 2020 17:11:20 +0000 Subject: [PATCH 4/4] Add option to save memory by letting distance matrix be overwritten --- hdbscan/hdbscan_.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hdbscan/hdbscan_.py b/hdbscan/hdbscan_.py index d31adf67..5c29e0e5 100644 --- a/hdbscan/hdbscan_.py +++ b/hdbscan/hdbscan_.py @@ -78,8 +78,12 @@ def _hdbscan_generic(X, min_samples=5, alpha=1.0, metric='minkowski', p=2, # sklearn.metrics.pairwise_distances handle it, # enables the usage of numpy.inf in the distance # matrix to indicate missing distance information. - # Need copy because distance_matrix may be modified if sparse - distance_matrix = X.copy() + # Give the user the option to have the distance matrix + # modified to save memory. + if kwargs.get("overwrite", False): + distance_matrix = X + else: + distance_matrix = X.copy() else: distance_matrix = pairwise_distances(X, metric=metric, **kwargs)