diff --git a/src/lineagetree/measure/spatial.py b/src/lineagetree/measure/spatial.py index 417ac05..52406c9 100644 --- a/src/lineagetree/measure/spatial.py +++ b/src/lineagetree/measure/spatial.py @@ -172,27 +172,36 @@ def compute_k_nearest_neighbours( dict mapping int to set of int dictionary that maps a node id to its `k` nearest neighbors + dict mapping int to set of float + dictionary that maps + a node id to the distances of its `k` nearest neighbors """ lT.kn_graph = {} - for t in set(lT._time.values()): - nodes = lT.time_nodes[t] + lT.kn_distances = {} + k = k + 1 + for t, nodes in lT.time_nodes.items(): if 1 < len(nodes): use_k = k if k < len(nodes) else len(nodes) idx3d, nodes = lT.get_idx3d(t) pos = [lT.pos[c] for c in nodes] - _, neighbs = idx3d.query(pos, use_k) + distances, neighbs = idx3d.query(pos, use_k) out = dict( zip( nodes, - map(set, nodes[neighbs]), + nodes[neighbs[:, 1:]], + strict=True, + ) + ) + out_distances = dict( + zip( + nodes, + distances[:, 1:], strict=True, ) ) lT.kn_graph.update(out) - else: - n = nodes.pop - lT.kn_graph.update({n: {n}}) - return lT.kn_graph + lT.kn_distances.update(out_distances) + return lT.kn_graph, lT.kn_distances def compute_spatial_edges( diff --git a/tests/test_lineageTree.py b/tests/test_lineageTree.py index bd0d54e..c7eb6f1 100644 --- a/tests/test_lineageTree.py +++ b/tests/test_lineageTree.py @@ -623,16 +623,22 @@ def test_spatial_density(): def test_compute_k_nearest_neighbours(): - assert lT1.compute_k_nearest_neighbours()[169994] == { - 108588, - 114722, - 129276, - 139163, - 148361, - 165681, - 169994, - 178396, - } + assert ( + lT1.compute_k_nearest_neighbours()[0][169994] + == [178396, 139163, 165681, 148361, 129276, 114722, 108588] + ).all() + assert np.allclose( + lT1.compute_k_nearest_neighbours()[1][169994], + [ + 34.39062611, + 50.72494649, + 58.97813932, + 71.48910824, + 80.50930278, + 171.6539282, + 173.50879512, + ], + ) def test_compute_spatial_edges():