-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathoctree.py
314 lines (258 loc) · 10.6 KB
/
octree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import random
import math
import numpy as np
import time
from result_set import KNNResultSet, RadiusNNResultSet
class Octant:
    """One cubic cell of an octree.

    Attributes (all set verbatim from the constructor arguments):
        children: list of child octants, ``None`` for empty slots. The builder
            creates 8 slots indexed by a 3-bit code (bit 0 = x, bit 1 = y,
            bit 2 = z above the center).
        center: coordinates of the cube's center.
        extent: half of the cube's side length (center-to-face distance).
        point_indices: indices into the point database of every point that
            falls inside this cell (internal nodes keep them too).
        is_leaf: whether this cell is a leaf of the tree.
    """

    def __init__(self, children, center, extent, point_indices, is_leaf):
        self.children = children
        self.center = center
        self.extent = extent
        self.point_indices = point_indices
        self.is_leaf = is_leaf

    def __str__(self):
        """Return a one-line summary: center, extent, leaf flag, child occupancy, indices."""
        pieces = (
            'center: [%.2f, %.2f, %.2f], ' % (self.center[0], self.center[1], self.center[2]),
            'extent: %.2f, ' % self.extent,
            'is_leaf: %d, ' % self.is_leaf,
            # Only report which child slots are occupied, not the children themselves.
            'children: ' + str([slot is not None for slot in self.children]) + ", ",
            'point_indices: ' + str(self.point_indices),
        )
        return ''.join(pieces)
def traverse_octree(root: "Octant", depth, max_depth):
    """Depth-first walk that prints every leaf and records the deepest level visited.

    :param root: octant to visit; ``None`` slots are visited too and count as a level
    :param depth: one-element list holding the current depth (mutated, restored on return)
    :param max_depth: one-element list holding the maximum depth seen so far (mutated)
    :return: None
    """
    depth[0] += 1
    max_depth[0] = max(max_depth[0], depth[0])

    if root is not None:
        if root.is_leaf:
            print(root)
        else:
            for sub_octant in root.children:
                traverse_octree(sub_octant, depth, max_depth)

    depth[0] -= 1
def octree_recursive_build(root, db, center, extent, point_indices, leaf_size, min_extent):
    """Recursively build (or extend) the octree cell for ``point_indices``.

    :param root: existing octant for this cell, or ``None`` to create one
    :param db: point database; ``db[i]`` must expose coordinates ``[0]``, ``[1]``, ``[2]``
    :param center: center of this cell
    :param extent: half side length of this cell
    :param point_indices: indices of the points that fall in this cell
    :param leaf_size: a cell with at most this many points becomes a leaf
    :param min_extent: a cell whose extent is at most this value becomes a leaf
        (stops unbounded recursion on duplicate points)
    :return: the octant for this cell, or ``None`` when it holds no points
    """
    if len(point_indices) == 0:
        return None

    if root is None:
        root = Octant([None] * 8, center, extent, point_indices, is_leaf=True)

    if len(point_indices) <= leaf_size or extent <= min_extent:
        root.is_leaf = True
        return root
    root.is_leaf = False

    # Bucket each point by its 3-bit child code: bit 0 = x, bit 1 = y, bit 2 = z
    # strictly above the center (points on a split plane go to the lower child).
    buckets = [[] for _ in range(8)]
    for idx in point_indices:
        pt = db[idx]
        code = ((1 if pt[0] > center[0] else 0)
                | (2 if pt[1] > center[1] else 0)
                | (4 if pt[2] > center[2] else 0))
        buckets[code].append(idx)

    # Each child is half the size and shifted by +/- half the parent extent per axis.
    half = 0.5 * extent
    for code, bucket in enumerate(buckets):
        child_center = np.asarray([
            center[axis] + (half if code & (1 << axis) else -half)
            for axis in range(3)
        ])
        root.children[code] = octree_recursive_build(
            root.children[code], db, child_center, half, bucket, leaf_size, min_extent
        )
    return root
def inside(query: np.ndarray, radius: float, octant: "Octant"):
    """
    Determines if the query ball lies strictly inside the octant.
    :param query: ball center
    :param radius: ball radius
    :param octant: cube described by ``center`` and half side length ``extent``
    :return: True when, on every axis, |query - center| + radius < extent
    """
    reach = np.fabs(query - octant.center) + radius
    return np.all(reach < octant.extent)
def overlaps(query: np.ndarray, radius: float, octant: "Octant"):
    """
    Determines if the query ball overlaps with the octant.
    :param query: ball center (3 coordinates)
    :param radius: ball radius
    :param octant: cube described by ``center`` and half side length ``extent``
    :return: True if the ball and the cube intersect
    """
    offset = np.fabs(query - octant.center)

    # Separating axis: farther than radius + extent along any axis -> disjoint.
    if np.any(offset > radius + octant.extent):
        return False

    # If at least two axes are within the cube's slab, the closest cube point
    # differs from the query only along the remaining axis, and the check above
    # already bounds that gap by the radius: the ball touches a face (or the
    # query is inside the cube).
    # NOTE: the original used ``.astype(np.int)``; that alias was removed in
    # NumPy 1.24 and raised AttributeError here.
    if np.count_nonzero(offset < octant.extent) >= 2:
        return True

    # Edge/corner case: the query lies outside the cube near an edge or corner.
    # Compare the squared distance to the nearest cube point with radius^2.
    gap = np.maximum(offset - octant.extent, 0.0)
    return np.sum(gap * gap) < radius * radius
def contains(query: np.ndarray, radius: float, octant: "Octant"):
    """
    Determine if the query ball fully contains the octant.
    :param query: ball center
    :param radius: ball radius
    :param octant: cube described by ``center`` and half side length ``extent``
    :return: True when the cube corner farthest from the query is closer than radius
    """
    # Per-axis distance from the query to the farthest face of the cube.
    farthest_corner = np.fabs(query - octant.center) + octant.extent
    return np.linalg.norm(farthest_corner) < radius
def octree_radius_search_fast(root: "Octant", db: np.ndarray, result_set: "RadiusNNResultSet", query: np.ndarray):
    """Radius search that short-circuits octants fully swallowed by the query ball.

    :param root: octant to search (``None`` yields nothing)
    :param db: point array indexed by the octants' ``point_indices``
    :param result_set: collector; ``worstDist()`` is used as the search radius and
        ``add_point(distance, index)`` records candidates
    :param query: query point
    :return: True when the search ball lies entirely inside ``root`` (caller may stop)
    """
    if root is None:
        return False

    # If the ball contains the whole octant, every stored point is a hit and no
    # child needs visiting; otherwise a non-empty leaf is scanned by brute force.
    swallowed = contains(query, result_set.worstDist(), root)
    if swallowed or (root.is_leaf and len(root.point_indices) > 0):
        distances = np.linalg.norm(np.expand_dims(query, 0) - db[root.point_indices, :], axis=1)
        for dist, point_idx in zip(distances, root.point_indices):
            result_set.add_point(dist, point_idx)
        if swallowed:
            return False
        return inside(query, result_set.worstDist(), root)

    # Every overlapping child is visited anyway, so the visiting order does not matter.
    for child in root.children:
        if child is None or not overlaps(query, result_set.worstDist(), child):
            continue
        if octree_radius_search_fast(child, db, result_set, query):
            return True

    return inside(query, result_set.worstDist(), root)
def octree_radius_search(root: "Octant", db: np.ndarray, result_set: "RadiusNNResultSet", query: np.ndarray):
    """Radius search that descends into the query's own child octant first.

    :param root: octant to search (``None`` yields nothing)
    :param db: point array indexed by the octants' ``point_indices``
    :param result_set: collector; ``worstDist()`` is used as the search radius and
        ``add_point(distance, index)`` records candidates
    :param query: query point
    :return: True when the search ball lies entirely inside ``root`` (caller may stop)
    """
    if root is None:
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        # Brute-force distances to every point stored in this leaf.
        distances = np.linalg.norm(np.expand_dims(query, 0) - db[root.point_indices, :], axis=1)
        for dist, point_idx in zip(distances, root.point_indices):
            result_set.add_point(dist, point_idx)
        return inside(query, result_set.worstDist(), root)

    # Child code of the octant containing the query (same bit layout as the build).
    preferred = ((1 if query[0] > root.center[0] else 0)
                 | (2 if query[1] > root.center[1] else 0)
                 | (4 if query[2] > root.center[2] else 0))
    if octree_radius_search(root.children[preferred], db, result_set, query):
        return True

    # Remaining children, skipping those the current ball cannot reach.
    for code, child in enumerate(root.children):
        if code == preferred or child is None:
            continue
        if not overlaps(query, result_set.worstDist(), child):
            continue
        if octree_radius_search(child, db, result_set, query):
            return True

    return inside(query, result_set.worstDist(), root)
def octree_knn_search(root: "Octant", db: np.ndarray, result_set: "KNNResultSet", query: np.ndarray):
    """k-nearest-neighbour search that descends into the query's own child octant first.

    :param root: octant to search (``None`` yields nothing)
    :param db: point array indexed by the octants' ``point_indices``
    :param result_set: collector; ``worstDist()`` is used as the current pruning
        radius and ``add_point(distance, index)`` records candidates
    :param query: query point
    :return: True when the pruning ball lies entirely inside ``root`` (caller may stop)
    """
    if root is None:
        return False

    if root.is_leaf and len(root.point_indices) > 0:
        # Brute-force distances to every point stored in this leaf.
        distances = np.linalg.norm(np.expand_dims(query, 0) - db[root.point_indices, :], axis=1)
        for dist, point_idx in zip(distances, root.point_indices):
            result_set.add_point(dist, point_idx)
        return inside(query, result_set.worstDist(), root)

    # Child code of the octant containing the query (same bit layout as the build).
    preferred = ((1 if query[0] > root.center[0] else 0)
                 | (2 if query[1] > root.center[1] else 0)
                 | (4 if query[2] > root.center[2] else 0))
    if octree_knn_search(root.children[preferred], db, result_set, query):
        return True

    # Remaining children, skipping those the current ball cannot reach.
    for code, child in enumerate(root.children):
        if code == preferred or child is None:
            continue
        if not overlaps(query, result_set.worstDist(), child):
            continue
        if octree_knn_search(child, db, result_set, query):
            return True

    return inside(query, result_set.worstDist(), root)
def octree_construction(db_np, leaf_size, min_extent):
    """Build an octree over a point cloud.

    :param db_np: (N, 3) array of points
    :param leaf_size: a cell with at most this many points becomes a leaf
    :param min_extent: a cell whose half side length is at most this value becomes a leaf
    :return: root octant, or ``None`` for an empty point cloud
    """
    num_points = db_np.shape[0]
    if num_points == 0:
        # np.amin/np.amax raise on empty arrays. An empty cloud simply has no
        # tree, and the search functions already treat a None root as empty.
        return None

    db_np_min = np.amin(db_np, axis=0)
    db_np_max = np.amax(db_np, axis=0)
    # The root must be a cube, so the widest axis decides its half side length.
    db_extent = np.max(db_np_max - db_np_min) * 0.5
    # Centering from the min corner keeps every axis inside [min, min + 2 * extent].
    db_center = db_np_min + db_extent
    return octree_recursive_build(None, db_np, db_center, db_extent, list(range(num_points)),
                                  leaf_size, min_extent)
def main():
    """Benchmark the two octree radius-search variants on random data.

    Builds an octree over ``db_size`` uniform random points in the unit cube and
    times ``octree_radius_search`` and ``octree_radius_search_fast`` on the same
    random queries, printing each total time in milliseconds.
    """
    # configuration
    db_size = 64000
    dim = 3
    leaf_size = 4
    min_extent = 0.0001
    k = 8  # only used by the disabled KNN sanity check below
    num_queries = 100
    radius = 0.5

    db_np = np.random.rand(db_size, dim)
    root = octree_construction(db_np, leaf_size, min_extent)

    # Disabled diagnostics: tree depth and a brute-force check of the KNN search.
    # depth = [0]
    # max_depth = [0]
    # traverse_octree(root, depth, max_depth)
    # print("tree max depth: %d" % max_depth[0])
    # query = np.asarray([0, 0, 0])
    # result_set = KNNResultSet(capacity=k)
    # octree_knn_search(root, db_np, result_set, query)
    # print(result_set)
    #
    # diff = np.linalg.norm(np.expand_dims(query, 0) - db_np, axis=1)
    # nn_idx = np.argsort(diff)
    # nn_dist = diff[nn_idx]
    # print(nn_idx[0:k])
    # print(nn_dist[0:k])

    # Draw the queries once so both variants are timed on identical inputs.
    queries = np.random.rand(num_queries, dim)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin_t = time.perf_counter()
    print("Radius search normal:")
    for query in queries:
        result_set = RadiusNNResultSet(radius=radius)
        octree_radius_search(root, db_np, result_set, query)
    print("Search takes %.3fms\n" % ((time.perf_counter() - begin_t) * 1000))

    begin_t = time.perf_counter()
    print("Radius search fast:")
    for query in queries:
        result_set = RadiusNNResultSet(radius=radius)
        octree_radius_search_fast(root, db_np, result_set, query)
    print("Search takes %.3fms\n" % ((time.perf_counter() - begin_t) * 1000))


if __name__ == '__main__':
    main()